sglang 0.4.9.post3__py3-none-any.whl → 0.4.9.post5__py3-none-any.whl
This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between these versions as they appear in their respective public registries.
- sglang/lang/chat_template.py +21 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/model_config.py +5 -1
- sglang/srt/constrained/base_grammar_backend.py +10 -2
- sglang/srt/constrained/xgrammar_backend.py +7 -5
- sglang/srt/conversation.py +17 -2
- sglang/srt/debug_utils/__init__.py +0 -0
- sglang/srt/debug_utils/dump_comparator.py +131 -0
- sglang/srt/debug_utils/dumper.py +108 -0
- sglang/srt/debug_utils/text_comparator.py +172 -0
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +13 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +65 -20
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/disaggregation/prefill.py +13 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +70 -15
- sglang/srt/entrypoints/engine.py +5 -9
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +148 -72
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +105 -66
- sglang/srt/function_call/function_call_parser.py +6 -4
- sglang/srt/function_call/glm4_moe_detector.py +164 -0
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/{qwen3_detector.py → qwen3_coder_detector.py} +11 -9
- sglang/srt/layers/activation.py +11 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/hybrid_attn_backend.py +100 -0
- sglang/srt/layers/attention/vision.py +56 -8
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/layernorm.py +26 -1
- sglang/srt/layers/logits_processor.py +46 -25
- sglang/srt/layers/moe/ep_moe/layer.py +172 -206
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=160,N=320,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +25 -224
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -48
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +11 -8
- sglang/srt/layers/moe/topk.py +88 -34
- sglang/srt/layers/multimodal.py +11 -8
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -9
- sglang/srt/layers/quantization/fp8.py +25 -247
- sglang/srt/layers/quantization/fp8_kernel.py +78 -48
- sglang/srt/layers/quantization/modelopt_quant.py +33 -14
- sglang/srt/layers/quantization/unquant.py +24 -76
- sglang/srt/layers/quantization/utils.py +0 -9
- sglang/srt/layers/quantization/w4afp8.py +68 -17
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/lora/lora_manager.py +133 -169
- sglang/srt/lora/lora_registry.py +188 -0
- sglang/srt/lora/mem_pool.py +2 -2
- sglang/srt/managers/cache_controller.py +62 -13
- sglang/srt/managers/io_struct.py +19 -1
- sglang/srt/managers/mm_utils.py +154 -35
- sglang/srt/managers/multimodal_processor.py +3 -14
- sglang/srt/managers/schedule_batch.py +27 -11
- sglang/srt/managers/scheduler.py +48 -26
- sglang/srt/managers/tokenizer_manager.py +62 -28
- sglang/srt/managers/tp_worker.py +5 -4
- sglang/srt/mem_cache/allocator.py +67 -7
- sglang/srt/mem_cache/hicache_storage.py +17 -1
- sglang/srt/mem_cache/hiradix_cache.py +35 -18
- sglang/srt/mem_cache/memory_pool_host.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +61 -25
- sglang/srt/model_executor/forward_batch_info.py +201 -29
- sglang/srt/model_executor/model_runner.py +109 -37
- sglang/srt/models/deepseek_v2.py +63 -30
- sglang/srt/models/glm4_moe.py +1035 -0
- sglang/srt/models/glm4_moe_nextn.py +167 -0
- sglang/srt/models/interns1.py +328 -0
- sglang/srt/models/internvl.py +143 -47
- sglang/srt/models/llava.py +9 -5
- sglang/srt/models/minicpmo.py +4 -1
- sglang/srt/models/mllama4.py +10 -3
- sglang/srt/models/qwen2_moe.py +2 -6
- sglang/srt/models/qwen3_moe.py +6 -8
- sglang/srt/multimodal/processors/base_processor.py +20 -6
- sglang/srt/multimodal/processors/clip.py +2 -2
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +2 -2
- sglang/srt/multimodal/processors/gemma3.py +2 -2
- sglang/srt/multimodal/processors/gemma3n.py +2 -2
- sglang/srt/multimodal/processors/internvl.py +21 -8
- sglang/srt/multimodal/processors/janus_pro.py +2 -2
- sglang/srt/multimodal/processors/kimi_vl.py +2 -2
- sglang/srt/multimodal/processors/llava.py +4 -4
- sglang/srt/multimodal/processors/minicpm.py +2 -3
- sglang/srt/multimodal/processors/mlama.py +2 -2
- sglang/srt/multimodal/processors/mllama4.py +18 -111
- sglang/srt/multimodal/processors/phi4mm.py +2 -2
- sglang/srt/multimodal/processors/pixtral.py +2 -2
- sglang/srt/multimodal/processors/qwen_audio.py +2 -2
- sglang/srt/multimodal/processors/qwen_vl.py +2 -2
- sglang/srt/multimodal/processors/vila.py +3 -1
- sglang/srt/reasoning_parser.py +48 -5
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/server_args.py +132 -60
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +37 -36
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +9 -5
- sglang/srt/utils.py +113 -69
- sglang/srt/weight_sync/utils.py +119 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_activation.py +50 -1
- sglang/test/test_utils.py +65 -5
- sglang/utils.py +19 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/METADATA +6 -6
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/RECORD +127 -114
- sglang/srt/debug_utils.py +0 -74
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post3.dist-info → sglang-0.4.9.post5.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"

 import logging
-from typing import Dict, Iterable, Optional, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple

 import torch

@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
@@ -55,6 +56,7 @@ class LoRAManager:
         tp_rank: int = 0,
         max_lora_rank: Optional[int] = None,
         target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -64,10 +66,6 @@ class LoRAManager:
         self.device: torch.device = next(self.base_model.parameters()).device
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
-        self.max_lora_rank: Optional[int] = max_lora_rank
-        self.target_modules: Optional[Set[str]] = (
-            set(target_modules) if target_modules else None
-        )

         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -75,7 +73,11 @@ class LoRAManager:
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)

         # Initialize mutable internal state of the LoRAManager.
-        self.init_state(
+        self.init_state(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+            lora_paths=lora_paths,
+        )

     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -112,108 +114,87 @@ class LoRAManager:
             success=success,
             error_message=error_message,
             loaded_adapters={
-
+                lora_ref.lora_name: lora_ref.lora_path
+                for lora_ref in self.lora_refs.values()
             },
         )

-    def
-        """
-        Load LoRA adapters from the specified paths.
-
-        Args:
-            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-                If a LoRA adapter is already loaded, it will be skipped with a warning.
-        """
-
-        results = []
-        for lora_name, lora_path in lora_paths.items():
-            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
-            results.append(result)
-
-        self.update_state_from_configs()
-
-        return self.create_lora_update_result(
-            success=all(result.success for result in results),
-            error_message="\n".join(
-                result.error_message for result in results if not result.success
-            ),
-        )
-
-    def load_lora_adapter(
-        self, lora_name: str, lora_path: str, update_state: bool = True
-    ) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Load a single LoRA adapter from the specified path.

         Args:
-
-            lora_path (str): The file path to the LoRA adapter.
-            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
         """
+        assert (
+            lora_ref.lora_name is not None and lora_ref.lora_path is not None
+        ), "LoRARef must have both lora_name and lora_path set for loading."
+        assert (
+            lora_ref.lora_id not in self.loras
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."

-
-
+        try:
+            # load configs
+            new_adapter = LoRAConfig(lora_ref.lora_path)
+            self.validate_new_adapter(new_adapter, lora_ref)
+            self.configs[lora_ref.lora_id] = new_adapter

-
-
-            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+            # load weights
+            self.load_lora_weights(lora_ref)

-
-
-            self.validate_new_adapter(lora_name, new_adapter)
-            self.configs[lora_name] = new_adapter
+            # keep metadata for displayed messages
+            self.lora_refs[lora_ref.lora_id] = lora_ref
         except Exception as e:
-
-
-
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
             )

-
-        self.update_state_from_configs()
+        return self.create_lora_update_result(success=True)

-
-            success=success,
-            error_message=error_message,
-        )
-
-    def validate_new_adapter(self, lora_name: str, lora_config: LoRAConfig):
+    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
         """
         Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
         """

-
-
-        )
+        memory_pool = getattr(self, "memory_pool", None)
+        incompatible = memory_pool and not memory_pool.can_support(lora_config)
         if incompatible:
             raise ValueError(
-                f"LoRA adapter {lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
                 "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
                 "included in `--enable_lora_modules`."
             )

-    def unload_lora_adapter(self,
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
         """

-
-
-
-
-        else:
-            error_message = f"LoRA adapter {lora_name} is not loaded."
-            success = False
+        adapter = self.configs.get(lora_ref.lora_id, None)
+        assert (
+            adapter is not None
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."

-
+        try:
+            del self.configs[lora_ref.lora_id]
+            del self.loras[lora_ref.lora_id]
+            del self.lora_refs[lora_ref.lora_id]
+        except Exception as e:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
+            )

-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
+        return self.create_lora_update_result(success=True)

     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        #
+        # Load active loras into lora memory pool
+        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
+        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
+        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
+        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -233,10 +214,10 @@ class LoRAManager:
         weight_indices = [0] * len(forward_batch.lora_paths)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i,
-            weight_indices[i] = self.memory_pool.get_buffer_id(
-            if
-                lora = self.loras[
+        for i, uid in enumerate(forward_batch.lora_paths):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
                 lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling

@@ -326,7 +307,7 @@ class LoRAManager:
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
-        for layer_id, layer_modules in self.lora_modules
+        for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
@@ -353,115 +334,94 @@ class LoRAManager:
                 ),
             )

-    def init_state(
+    def init_state(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
+    ):
         """
         Initialize the internal (mutable) state of the LoRAManager.

-
+        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
+        the target modules and max_lora_rank.
         """

-
-
-
-        # LoRA adapter weights cached in CPU memory.
-        self.loras: Dict[str, LoRAAdapter] = {}
+        assert lora_paths or (
+            max_lora_rank is not None and target_modules is not None
+        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."

-
-        self.
-
-
-
-
-
+        self.init_lora_adapters(lora_paths)
+        self.init_lora_shapes(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+        )
+        self.init_lora_weight_names()
+        self.init_lora_modules()
+        self.init_memory_pool()

-
-        #
-        self.
+    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        # Configs of all active LoRA adapters, indexed by LoRA ID.
+        self.configs: Dict[str, LoRAConfig] = {}

-
-
-        Update the internal state of the LoRAManager based on the current `self.configs`. This method
-        should be called whenever `self.configs` is modified (e.g., when new LoRA adapters are loaded).
-        """
+        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
+        self.loras: Dict[str, LoRAAdapter] = {}

-        #
-        self.
-        # Apply the latest LoRA configurations to the internal state for inferencing.
-        self.apply_lora_configs()
+        # Mapping from LoRA ID to LoRARef object.
+        self.lora_refs: Dict[str, LoRARef] = {}

-
-
-
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                result = self.load_lora_adapter(lora_ref)
+                if not result.success:
+                    raise RuntimeError(
+                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
+                    )

-
-
-
-
-
-        """
+    def init_lora_shapes(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+    ):
+        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""

-        if
-
-
-
-
-
-
-        if self.max_lora_rank is None:
-            self.max_lora_rank = max(
-                [x.hf_config["r"] for x in self.configs.values()],
-                default=0,
-            )
+        if target_modules is not None:
+            self.target_modules = set(target_modules)
+        else:
+            self.target_modules = set()
+            for config in self.configs.values():
+                self.target_modules.update(config.target_modules)

-
-        self.
-        self.update_memory_buffers()
+        if max_lora_rank is not None:
+            self.max_lora_rank = max_lora_rank
         else:
-
-
-
-            assert self.memory_pool.can_support(self.configs.values()), (
-                "LoRA memory pool cannot support the current LoRA configuration. "
-                "This should never happen as we should have validated adapter compatibility. "
-                "Please create a Github issue to report.",
+            self.max_lora_rank = max(
+                [x.hf_config["r"] for x in self.configs.values()],
+                default=0,
             )

-    def
+    def init_lora_weight_names(self):
         """
         Add new LoRA weight names if needed based on the current `self.configs`.
         """

         # Target lora weight names for lora_a and lora_b modules respectively.
         lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
-        self.lora_weight_names[
-        self.lora_weight_names[1].update(lora_B)
+        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))

-    def
+    def load_lora_weights(self, lora_ref: LoRARef):
         """
-
-        It loads any new adapters that are not already loaded, and unloads any adapters
-        that are no longer in `self.configs` (e.g., unloaded).
+        Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
         """
-
-
-
-
-
-
-
-
-
-                self.load_config,
-                self.lora_backend,
-            )
-            lora_adapter.initialize_weights()
-            self.loras[name] = lora_adapter
-
-        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
-        for name in list(self.loras):
-            if name not in self.configs:
-                logger.info(f"Unloading LoRA adapter {name}")
-                del self.loras[name]
+        lora_adapter = LoRAAdapter(
+            lora_ref.lora_id,
+            self.configs[lora_ref.lora_id],
+            self.base_hf_config,
+            self.load_config,
+            self.lora_backend,
+        )
+        lora_adapter.initialize_weights()
+        self.loras[lora_ref.lora_id] = lora_adapter

         # Additional checks for flashinfer backend
         # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
@@ -472,7 +432,7 @@ class LoRAManager:
             len(lora_dims) == 1 and len(scalings) == 1
         ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "

-    def
+    def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
             base_hf_config=self.base_hf_config,
@@ -490,7 +450,12 @@ class LoRAManager:
             replace_submodule(self.base_model, module_name, lora_module)
         return lora_module

-    def
+    def init_lora_modules(self):
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
+            {} for _ in range(self.base_hf_config.num_hidden_layers)
+        ]
+
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
@@ -511,7 +476,6 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-
-
-
-                )
+                self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                    module_name, module
+                )
sglang/srt/lora/lora_registry.py
ADDED
@@ -0,0 +1,188 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import asyncio
+from collections import defaultdict
+from dataclasses import dataclass, field, fields
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+from sglang.srt.aio_rwlock import RWLock
+from sglang.srt.utils import ConcurrentCounter
+
+
+@dataclass(frozen=True)
+class LoRARef:
+    """
+    Reference record for a LoRA model.
+
+    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
+    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    keys (e.g., radix cache).
+    """
+
+    lora_id: str = field(default_factory=lambda: uuid4().hex)
+    lora_name: Optional[str] = None
+    lora_path: Optional[str] = None
+
+    def __post_init__(self):
+        if self.lora_id is None:
+            raise ValueError("lora_id cannot be None")
+
+    def __str__(self) -> str:
+        parts = [
+            f"{f.name}={value}"
+            for f in fields(self)
+            if (value := getattr(self, f.name)) is not None
+        ]
+        return f"{self.__class__.__name__}({', '.join(parts)})"
+
+
+class LoRARegistry:
+    """
+    The central registry to keep track of available LoRA adapters and ongoing LoRA requests.
+
+    The `LoRARegistry` resides in the tokenizer manager process and acts as the single source of truth for all
+    available LoRA adapters. It supports concurrent inference and dynamic adapter updates through a two-phase
+    update / eventual consistency model between the tokenizer manager process and the scheduler processes.
+    """
+
+    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        assert lora_paths is None or all(
+            isinstance(lora, LoRARef) for lora in lora_paths.values()
+        ), (
+            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
+            "Please file an issue if you see this error."
+        )
+
+        # A read-write lock to ensure adapters loading / unloading operations are exclusive.
+        # Please note that the counter increment/decrement operations are not synchronized through this
+        # lock, as they are designed to be non-blocking and can be performed concurrently.
+        self._registry_lock = RWLock()
+        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
+        self._registry: Dict[str, LoRARef] = {}
+        # Counters for ongoing requests, mapping from LoRA ID to ConcurrentCounter.
+        self._counters: Dict[str, ConcurrentCounter] = {}
+
+        # Initialize the registry with provided LoRA paths, if present.
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                self._register_adapter(lora_ref)
+
+    async def register(self, lora_ref: LoRARef):
+        """
+        Register a new LoRARef object in the registry.
+
+        Args:
+            lora_ref (LoRARef): The LoRARef object to register.
+        """
+        async with self._registry_lock.writer_lock:
+            self._register_adapter(lora_ref)
+
+    async def unregister(self, lora_name: str) -> str:
+        """
+        Unregister a LoRARef object from the registry and returns the removed LoRA ID.
+
+        Args:
+            lora_name (str): The name of the LoRA model to unregister.
+        """
+        async with self._registry_lock.writer_lock:
+            lora_ref = self._registry.get(lora_name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+                )
+            del self._registry[lora_name]
+            del self._counters[lora_ref.lora_id]
+
+        return lora_ref.lora_id
+
+    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
+        """
+        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
+        by incrementing its counter.
+        """
+
+        def _lookup(name: str) -> str:
+            lora_ref = self._registry.get(name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"The following requested LoRA adapters are not loaded: {name}\n"
+                    f"Loaded adapters: {self._registry.keys()}."
+                )
+            return lora_ref.lora_id
+
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_name, str):
+                lora_id = _lookup(lora_name)
+                await self._counters[lora_id].increment(notify_all=False)
+                return lora_id
+            elif isinstance(lora_name, list):
+                lora_ids = [_lookup(name) for name in lora_name]
+
+                # Increment the counters only after all IDs are looked up.
+                await asyncio.gather(
+                    *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                )
+                return lora_ids
+            else:
+                raise TypeError(
+                    "lora_name must be either a string or a list of strings."
+                )
+
+    async def release(self, lora_id: Union[str, List[str]]):
+        """
+        Decrements the usage counter for a LoRA adapter, indicating that it is no longer in use.
+        """
+
+        async with self._registry_lock.reader_lock:
+            if isinstance(lora_id, str):
+                await self._counters[lora_id].decrement()
+            elif isinstance(lora_id, list):
+                await asyncio.gather(
+                    *[self._counters[id].decrement() for id in lora_id]
+                )
+            else:
+                raise TypeError("lora_id must be either a string or a list of strings.")
+
+    async def wait_for_unload(self, lora_id: str):
+        """
+        Waits until the usage counter for a LoRA adapter reaches zero, indicating that it is no longer in use.
+        This is useful for ensuring that a LoRA adapter can be safely unloaded.
+
+        This method itself is not synchronized, which is safe because it should only be called during LoRA unloading,
+        which itself is guaranteed to be sequential.
+        """
+        assert (
+            lora_id not in self._registry
+        ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
+        counter = self._counters.get(lora_id)
+        if counter:
+            # Wait until no requests are using this LoRA adapter.
+            await counter.wait_for_zero()
+            del self._counters[lora_id]
+
+    def _register_adapter(self, lora_ref: LoRARef):
+        """
+        Internal helper method to register a LoRA adapter.
+        """
+
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
+            )
+        self._registry[lora_ref.lora_name] = lora_ref
+        self._counters[lora_ref.lora_id] = ConcurrentCounter()
+        return lora_ref
sglang/srt/lora/mem_pool.py
CHANGED
@@ -153,7 +153,7 @@ class LoRAMemoryPool:
         self,
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
-        lora_modules:
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
@@ -186,7 +186,7 @@ class LoRAMemoryPool:
         uid: str,
         buffer_id: int,
         lora_adapter: LoRAAdapter,
-        lora_modules:
+        lora_modules: List[Dict[str, BaseLayerWithLoRA]],
     ):
         def load_lora_weight_tensor(
             buffer_view: torch.Tensor, weight: Optional[torch.Tensor]