sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
@@ -16,7 +16,7 @@
  # and "Punica: Multi-Tenant LoRA Serving"

  import logging
- from typing import Dict, Set, Tuple
+ from typing import Dict, Iterable, List, Optional, Set, Tuple

  import torch

@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
  from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
  from sglang.srt.lora.lora import LoRAAdapter
  from sglang.srt.lora.lora_config import LoRAConfig
+ from sglang.srt.lora.lora_registry import LoRARef
  from sglang.srt.lora.mem_pool import LoRAMemoryPool
  from sglang.srt.lora.utils import (
      LoRABatchInfo,
@@ -53,6 +54,9 @@ class LoRAManager:
          lora_backend: str = "triton",
          tp_size: int = 1,
          tp_rank: int = 0,
+         max_lora_rank: Optional[int] = None,
+         target_modules: Optional[Iterable[str]] = None,
+         lora_paths: Optional[Dict[str, LoRARef]] = None,
      ):
          self.base_model: torch.nn.Module = base_model
          self.base_hf_config: AutoConfig = base_hf_config
@@ -69,7 +73,11 @@ class LoRAManager:
          self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

          # Initialize mutable internal state of the LoRAManager.
-         self.init_state()
+         self.init_state(
+             max_lora_rank=max_lora_rank,
+             target_modules=target_modules,
+             lora_paths=lora_paths,
+         )

      def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
          self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -106,91 +114,87 @@ class LoRAManager:
              success=success,
              error_message=error_message,
              loaded_adapters={
-                 name: config.path for name, config in self.configs.items()
+                 lora_ref.lora_name: lora_ref.lora_path
+                 for lora_ref in self.lora_refs.values()
              },
          )

-     def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
-         """
-         Load LoRA adapters from the specified paths.
-
-         Args:
-             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-                 If a LoRA adapter is already loaded, it will be skipped with a warning.
-         """
-
-         results = []
-         for lora_name, lora_path in lora_paths.items():
-             result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
-             results.append(result)
-
-         self.update_state_from_configs()
-
-         return self.create_lora_update_result(
-             success=all(result.success for result in results),
-             error_message="\n".join(
-                 result.error_message for result in results if not result.success
-             ),
-         )
-
-     def load_lora_adapter(
-         self, lora_name: str, lora_path: str, update_state: bool = True
-     ) -> LoRAUpdateResult:
+     def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
          """
          Load a single LoRA adapter from the specified path.

          Args:
-             lora_name (str): The name of the LoRA adapter.
-             lora_path (str): The file path to the LoRA adapter.
-             update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+             lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
          """
+         assert (
+             lora_ref.lora_name is not None and lora_ref.lora_path is not None
+         ), "LoRARef must have both lora_name and lora_path set for loading."
+         assert (
+             lora_ref.lora_id not in self.loras
+         ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

-         success = True
-         error_message = ""
+         try:
+             # load configs
+             new_adapter = LoRAConfig(lora_ref.lora_path)
+             self.validate_new_adapter(new_adapter, lora_ref)
+             self.configs[lora_ref.lora_id] = new_adapter

-         if lora_name in self.loras:
-             success = False
-             error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+             # load weights
+             self.load_lora_weights(lora_ref)

-         try:
-             self.configs[lora_name] = LoRAConfig(lora_path)
+             # keep metadata for displayed messages
+             self.lora_refs[lora_ref.lora_id] = lora_ref
          except Exception as e:
-             success = False
-             error_message = (
-                 f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+             return self.create_lora_update_result(
+                 success=False,
+                 error_message=str(e),
              )

-         if update_state:
-             self.update_state_from_configs()
+         return self.create_lora_update_result(success=True)

-         return self.create_lora_update_result(
-             success=success,
-             error_message=error_message,
-         )
+     def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
+         """
+         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
+         """

-     def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
+         memory_pool = getattr(self, "memory_pool", None)
+         incompatible = memory_pool and not memory_pool.can_support(lora_config)
+         if incompatible:
+             raise ValueError(
+                 f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                 "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
+                 "included in `--enable_lora_modules`."
+             )
+
+     def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
          """
          Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
          delete the corresponding LoRA modules.
          """

-         success = True
-         error_message = ""
-         if lora_name in self.loras:
-             del self.configs[lora_name]
-         else:
-             error_message = f"LoRA adapter {lora_name} is not loaded."
-             success = False
+         adapter = self.configs.get(lora_ref.lora_id, None)
+         assert (
+             adapter is not None
+         ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."

-         self.update_state_from_configs()
+         try:
+             del self.configs[lora_ref.lora_id]
+             del self.loras[lora_ref.lora_id]
+             del self.lora_refs[lora_ref.lora_id]
+         except Exception as e:
+             return self.create_lora_update_result(
+                 success=False,
+                 error_message=str(e),
+             )

-         return self.create_lora_update_result(
-             success=success,
-             error_message=error_message,
-         )
+         return self.create_lora_update_result(success=True)

      def prepare_lora_batch(self, forward_batch: ForwardBatch):
-         # load active loras into lora memory pool
+         # Load active loras into lora memory pool
+         # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
+         # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
+         # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
+         # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
          cur_uids = set(forward_batch.lora_paths)
          assert len(cur_uids) <= self.max_loras_per_batch
          self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
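The hunk above replaces the name/path-keyed batch API (`load_lora_adapters`) with per-adapter `load_lora_adapter(lora_ref)` / `unload_lora_adapter(lora_ref)` calls keyed by `lora_ref.lora_id`. The sketch below is not part of the diff: the helper names are hypothetical, and it assumes an already-constructed `LoRAManager` instance; only `LoRARef` and the two manager methods are taken from this release.

from sglang.srt.lora.lora_registry import LoRARef


def load_adapter(lora_manager, name: str, path: str) -> LoRARef:
    # lora_id is auto-generated (uuid4 hex) by LoRARef, so reused names/paths stay distinct.
    ref = LoRARef(lora_name=name, lora_path=path)
    result = lora_manager.load_lora_adapter(ref)
    if not result.success:
        # Incompatibilities (e.g., rank above the configured maximum) are reported
        # through the result object rather than raised by the manager.
        raise RuntimeError(f"loading {name} failed: {result.error_message}")
    return ref


def unload_adapter(lora_manager, ref: LoRARef) -> None:
    # Unloading takes the same LoRARef; the manager keys its internal state by ref.lora_id.
    result = lora_manager.unload_lora_adapter(ref)
    if not result.success:
        raise RuntimeError(f"unloading {ref.lora_name} failed: {result.error_message}")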
@@ -210,11 +214,11 @@ class LoRAManager:
          weight_indices = [0] * len(forward_batch.lora_paths)
          lora_ranks = [0] * self.max_loras_per_batch
          scalings = [0] * self.max_loras_per_batch
-         for i, lora_path in enumerate(forward_batch.lora_paths):
-             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
-             if lora_path is not None:
-                 lora = self.loras[lora_path]
-                 lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+         for i, uid in enumerate(forward_batch.lora_paths):
+             weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+             if uid is not None:
+                 lora = self.loras[uid]
+                 lora_ranks[weight_indices[i]] = lora.config.r
                  scalings[weight_indices[i]] = lora.scaling

          # Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -303,7 +307,7 @@ class LoRAManager:
          """
          Update all LoRA modules to associate them with the latest memory buffer.
          """
-         for layer_id, layer_modules in self.lora_modules.items():
+         for layer_id, layer_modules in enumerate(self.lora_modules):
              for module_name, module in layer_modules.items():
                  if "qkv_proj" in module_name:
                      module.set_lora_info(
@@ -319,7 +323,7 @@ class LoRAManager:
                      )
                  else:
                      weight_name = get_weight_name(
-                         module_name, self.lora_weight_names, LoRAType.LORA_A
+                         module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                      )
                      module.set_lora_info(
                          self.memory_pool.get_tensor(
@@ -330,125 +334,115 @@ class LoRAManager:
                          ),
                      )

-     def init_state(self):
+     def init_state(
+         self,
+         max_lora_rank: Optional[int] = None,
+         target_modules: Optional[Iterable[str]] = None,
+         lora_paths: Optional[Dict[str, LoRARef]] = None,
+     ):
          """
          Initialize the internal (mutable) state of the LoRAManager.

-         These states are mutable via the `update_state_from_configs` as LoRA adapters are loaded and unloaded dynamically.
+         When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
+         the target modules and max_lora_rank.
          """

-         # Configs of all active LoRA adapters.
+         assert lora_paths or (
+             max_lora_rank is not None and target_modules is not None
+         ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
+
+         self.init_lora_adapters(lora_paths)
+         self.init_lora_shapes(
+             max_lora_rank=max_lora_rank,
+             target_modules=target_modules,
+         )
+         self.init_lora_weight_names()
+         self.init_lora_modules()
+         self.init_memory_pool()
+
+     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+         # Configs of all active LoRA adapters, indexed by LoRA ID.
          self.configs: Dict[str, LoRAConfig] = {}

-         # LoRA adapter weights cached in CPU memory.
+         # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
          self.loras: Dict[str, LoRAAdapter] = {}

-         # Supported weight names (e.g., qkv_proj) for LoRA A and B respectively.
-         self.lora_weight_names: Tuple[Set[str]] = (set(), set())
+         # Mapping from LoRA ID to LoRARef object.
+         self.lora_refs: Dict[str, LoRARef] = {}

-         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
-         self.lora_modules: Dict[int, Dict[str, BaseLayerWithLoRA]] = {
-             i: {} for i in range(self.base_hf_config.num_hidden_layers)
-         }
+         if lora_paths:
+             for lora_ref in lora_paths.values():
+                 result = self.load_lora_adapter(lora_ref)
+                 if not result.success:
+                     raise RuntimeError(
+                         f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
+                     )

-         # Initialize memory pool
-         self.memory_pool = LoRAMemoryPool(
-             self.base_hf_config,
-             self.max_loras_per_batch,
-             self.dtype,
-             self.tp_size,
-             self.tp_rank,
-         )
+     def init_lora_shapes(
+         self,
+         max_lora_rank: Optional[int] = None,
+         target_modules: Optional[Iterable[str]] = None,
+     ):
+         """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

-     def update_state_from_configs(self):
-         """
-         Update the internal state of the LoRAManager based on the current `self.configs`. This method
-         should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-
-         This includes:
-         - Initializing LoRA adapters if they are not already loaded.
-         - Collect all LoRA weight names based on the current loaded adapters.
-         - Lazily monkey-patching the base model to use LoRA layers where applicable.
-         - Preparing the GPU buffer pool for active LoRA weights.
-         """
+         if target_modules is not None:
+             self.target_modules = set(target_modules)
+         else:
+             self.target_modules = set()
+             for config in self.configs.values():
+                 self.target_modules.update(config.target_modules)

-         # Target module names in huggingface lora configs.
-         # e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-         hf_target_module_names: Set[str] = set()
-         for config in self.configs.values():
-             hf_target_module_names.update(config.target_modules)
-         max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-         # Loads / unloads LoRA adapters based on the latest configs.
-         self.update_lora_adapters()
-
-         # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-         #
-         # Please note that the following update operations are "monotonic" by design, meaning that we update
-         # multiple places to support the new weight names when the first adapter targeting such weight names
-         # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-         # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-         # list of LoRA weight names is expected to be extremely finite and stable.
-         self.update_lora_weight_names(hf_target_module_names)
-         self.update_lora_modules(hf_target_module_names)
-         self.update_memory_buffers(max_lora_dim)
-
-     def update_lora_weight_names(self, hf_target_names: Set[str]):
+         if max_lora_rank is not None:
+             self.max_lora_rank = max_lora_rank
+         else:
+             self.max_lora_rank = max(
+                 [x.hf_config["r"] for x in self.configs.values()],
+                 default=0,
+             )
+
+     def init_lora_weight_names(self):
          """
          Add new LoRA weight names if needed based on the current `self.configs`.
          """

          # Target lora weight names for lora_a and lora_b modules respectively.
-         for module in hf_target_names:
-             lora_A, lora_B = get_normalized_lora_weight_names(module)
-             self.lora_weight_names[0].update(lora_A)
-             self.lora_weight_names[1].update(lora_B)
+         lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+         self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))

-     def update_lora_adapters(self):
+     def load_lora_weights(self, lora_ref: LoRARef):
          """
-         Update the LoRA adapters in CPU memory based on the current `self.configs`.
-         It loads any new adapters that are not already loaded, and unloads any adapters
-         that are no longer in `self.configs` (e.g., unloaded).
+         Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
          """
-
-         # Load new adapter weights to cpu
-         for name, config in self.configs.items():
-             if name not in self.loras:
-                 logger.info(f"Loading weight of LoRA adapter {name} from {config.path}")
-                 lora_adapter = LoRAAdapter(
-                     name,
-                     config,
-                     self.base_hf_config,
-                     self.load_config,
-                     self.lora_backend,
-                 )
-                 lora_adapter.initialize_weights()
-                 self.loras[name] = lora_adapter
-
-         # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
-         for name in list(self.loras):
-             if name not in self.configs:
-                 logger.info(f"Unloading LoRA adapter {name}")
-                 del self.loras[name]
+         lora_adapter = LoRAAdapter(
+             lora_ref.lora_id,
+             self.configs[lora_ref.lora_id],
+             self.base_hf_config,
+             self.load_config,
+             self.lora_backend,
+         )
+         lora_adapter.initialize_weights()
+         self.loras[lora_ref.lora_id] = lora_adapter

          # Additional checks for flashinfer backend
          # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
          if self.lora_backend == "flashinfer":
-             lora_dims = set(x.hf_config["r"] for x in self.configs.values())
+             lora_dims = set(x.r for x in self.configs.values())
              scalings = set(x.scaling for x in self.loras.values())
              assert (
                  len(lora_dims) == 1 and len(scalings) == 1
              ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-     def update_memory_buffers(self, max_lora_dim: int):
-         """
-         Update the LoRA memory pool buffers based on the current LoRA configurations and update
-         LoRA modules to use the new buffers. This method should be called after the LoRA configurations
-         are set or updated.
-         """
-
-         self.memory_pool.init_buffers(
-             self.lora_weight_names, self.base_model, max_lora_dim
+     def init_memory_pool(self):
+         """(Re)initialize the LoRA memory pool based on the current configurations."""
+         self.memory_pool = LoRAMemoryPool(
+             base_hf_config=self.base_hf_config,
+             max_loras_per_batch=self.max_loras_per_batch,
+             dtype=self.dtype,
+             tp_size=self.tp_size,
+             tp_rank=self.tp_rank,
+             max_lora_rank=self.max_lora_rank,
+             lora_weight_names=self.lora_weight_names,
+             base_model=self.base_model,
          )

      def set_lora_module(self, module_name, module):
@@ -456,11 +450,16 @@ class LoRAManager:
          replace_submodule(self.base_model, module_name, lora_module)
          return lora_module

-     def update_lora_modules(self, hf_target_names: Set[str]):
+     def init_lora_modules(self):
+         # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+         self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
+             {} for _ in range(self.base_hf_config.num_hidden_layers)
+         ]
+
          # Target module names of customized layers defined in python/sglang/srt/layers
          # e.g., {"qkv_proj", "o_proj"}
          customized_target_names = get_customized_names_from_hf_names(
-             hf_target_names, self.base_model
+             self.target_modules, self.base_model
          )

          for module_name, module in self.base_model.named_modules():
@@ -477,7 +476,6 @@
              # The module should be converted if it is included in target_names
              if module_name.split(".")[-1] in customized_target_names:
                  layer_id = get_layer_id(module_name)
-                 if module_name not in self.lora_modules[layer_id]:
-                     self.lora_modules[layer_id][module_name] = self.set_lora_module(
-                         module_name, module
-                     )
+                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                     module_name, module
+                 )
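Taken together, the lora_manager.py changes replace the old `update_state_from_configs` flow with a one-shot `init_state` that either receives `max_lora_rank` / `target_modules` explicitly or infers them from the initially supplied adapters. The snippet below is a standalone paraphrase of that inference fallback, not the shipped code; the function name and the dictionary shapes are assumptions made only for this illustration.

from typing import Dict, Iterable, Optional, Set, Tuple


def infer_lora_shapes(
    adapter_ranks: Dict[str, int],  # LoRA ID -> rank ("r" in the adapter's HF config)
    adapter_targets: Dict[str, Set[str]],  # LoRA ID -> target module names
    max_lora_rank: Optional[int] = None,
    target_modules: Optional[Iterable[str]] = None,
) -> Tuple[int, Set[str]]:
    # Explicit server arguments win; otherwise fall back to the union/max over loaded adapters.
    if target_modules is not None:
        modules = set(target_modules)
    else:
        modules = set()
        for mods in adapter_targets.values():
            modules.update(mods)

    if max_lora_rank is not None:
        rank = max_lora_rank
    else:
        rank = max(adapter_ranks.values(), default=0)

    return rank, modules


# Two adapters, no explicit overrides -> rank 16 and the union {"q_proj", "v_proj"}.
print(infer_lora_shapes({"a": 8, "b": 16}, {"a": {"q_proj"}, "b": {"q_proj", "v_proj"}}))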
sglang/srt/lora/lora_registry.py (new file)
@@ -0,0 +1,124 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+
+ import asyncio
+ from dataclasses import dataclass, field, fields
+ from typing import Dict, List, Optional, Union
+ from uuid import uuid4
+
+
+ @dataclass(frozen=True, slots=True)
+ class LoRARef:
+     """
+     Reference record for a LoRA model.
+
+     This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
+     eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+     keys (e.g., radix cache).
+     """
+
+     lora_id: str = field(default_factory=lambda: uuid4().hex)
+     lora_name: Optional[str] = None
+     lora_path: Optional[str] = None
+
+     def __post_init__(self):
+         if self.lora_id is None:
+             raise ValueError("lora_id cannot be None")
+
+     def __str__(self) -> str:
+         parts = [
+             f"{f.name}={value}"
+             for f in fields(self)
+             if (value := getattr(self, f.name)) is not None
+         ]
+         return f"{self.__class__.__name__}({', '.join(parts)})"
+
+
+ class LoRARegistry:
+     """
+     The central registry to keep track of available LoRA adapters.
+
+     TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
+     to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
+     """
+
+     def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+         assert lora_paths is None or all(
+             isinstance(lora, LoRARef) for lora in lora_paths.values()
+         ), (
+             "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
+             "Please file an issue if you see this error."
+         )
+
+         # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
+         self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
+
+     async def register(self, lora_ref: LoRARef):
+         """
+         Register a new LoRARef object in the registry.
+
+         Args:
+             lora_ref (LoRARef): The LoRARef object to register.
+         """
+         if lora_ref.lora_name in self._registry:
+             raise ValueError(
+                 f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
+             )
+         self._registry[lora_ref.lora_name] = lora_ref
+
+     async def unregister(self, lora_name: str) -> str:
+         """
+         Unregister a LoRARef object from the registry and returns the removed LoRA ID.
+
+         Args:
+             lora_name (str): The name of the LoRA model to unregister.
+         """
+         lora_ref = self._registry.get(lora_name, None)
+         if lora_ref is None:
+             raise ValueError(
+                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+             )
+         del self._registry[lora_name]
+
+         return lora_ref.lora_id
+
+     async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
+         """
+         Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
+         by incrementing its counter.
+
+         TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
+         """
+
+         async def _acquire_single(name: str) -> str:
+             lora_ref = self._registry.get(name, None)
+             if lora_ref is None:
+                 raise ValueError(
+                     f"The following requested LoRA adapters are not loaded: {name}\n"
+                     f"Loaded adapters: {self._registry.keys()}."
+                 )
+             # await self._counters[lora_ref.lora_id].increment()
+             return lora_ref.lora_id
+
+         if isinstance(lora_name, str):
+             lora_id = await _acquire_single(lora_name)
+             return lora_id
+         elif isinstance(lora_name, list):
+             lora_ids = await asyncio.gather(
+                 *[_acquire_single(name) for name in lora_name]
+             )
+             return lora_ids
+         else:
+             raise TypeError("lora_name must be either a string or a list of strings.")
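A hedged usage sketch for the new registry (not part of the diff): the adapter name "sql-adapter" and its path are made-up examples, and the snippet assumes the wheel is installed so that sglang.srt.lora.lora_registry is importable.

import asyncio

from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry


async def main() -> None:
    registry = LoRARegistry()

    # Register an adapter under a human-readable name; the stable lora_id is generated automatically.
    ref = LoRARef(lora_name="sql-adapter", lora_path="/models/loras/sql-adapter")
    await registry.register(ref)

    # acquire() maps names to IDs; the ID is what the rest of the system (e.g., cache keys) uses.
    lora_id = await registry.acquire("sql-adapter")
    assert lora_id == ref.lora_id

    # unregister() removes the entry and returns the ID that was dropped.
    removed = await registry.unregister("sql-adapter")
    assert removed == ref.lora_id


asyncio.run(main())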