sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of publicly available package versions as released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_manager.py
CHANGED
@@ -16,7 +16,7 @@
 # and "Punica: Multi-Tenant LoRA Serving"
 
 import logging
-from typing import Dict, Set, Tuple
+from typing import Dict, Iterable, List, Optional, Set, Tuple
 
 import torch
 
@@ -26,6 +26,7 @@ from sglang.srt.lora.backend.base_backend import BaseLoRABackend, get_backend_fr
 from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
@@ -53,6 +54,9 @@ class LoRAManager:
         lora_backend: str = "triton",
         tp_size: int = 1,
         tp_rank: int = 0,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
     ):
         self.base_model: torch.nn.Module = base_model
         self.base_hf_config: AutoConfig = base_hf_config
@@ -69,7 +73,11 @@ class LoRAManager:
         self.lora_backend: BaseLoRABackend = backend_type(lora_backend)
 
         # Initialize mutable internal state of the LoRAManager.
-        self.init_state(
+        self.init_state(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+            lora_paths=lora_paths,
+        )
 
     def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
         self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
@@ -106,91 +114,87 @@ class LoRAManager:
             success=success,
             error_message=error_message,
             loaded_adapters={
-
+                lora_ref.lora_name: lora_ref.lora_path
+                for lora_ref in self.lora_refs.values()
             },
         )
 
-    def
-        """
-        Load LoRA adapters from the specified paths.
-
-        Args:
-            lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
-            If a LoRA adapter is already loaded, it will be skipped with a warning.
-        """
-
-        results = []
-        for lora_name, lora_path in lora_paths.items():
-            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
-            results.append(result)
-
-        self.update_state_from_configs()
-
-        return self.create_lora_update_result(
-            success=all(result.success for result in results),
-            error_message="\n".join(
-                result.error_message for result in results if not result.success
-            ),
-        )
-
-    def load_lora_adapter(
-        self, lora_name: str, lora_path: str, update_state: bool = True
-    ) -> LoRAUpdateResult:
+    def load_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Load a single LoRA adapter from the specified path.
 
         Args:
-
-            lora_path (str): The file path to the LoRA adapter.
-            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+            lora_ref (LoRARef): The LoRARef object containing the LoRA name, path, and ID.
         """
+        assert (
+            lora_ref.lora_name is not None and lora_ref.lora_path is not None
+        ), "LoRARef must have both lora_name and lora_path set for loading."
+        assert (
+            lora_ref.lora_id not in self.loras
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is already loaded. This should have been verified before request is sent to the backend."
 
-
-
+        try:
+            # load configs
+            new_adapter = LoRAConfig(lora_ref.lora_path)
+            self.validate_new_adapter(new_adapter, lora_ref)
+            self.configs[lora_ref.lora_id] = new_adapter
 
-
-
-            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+            # load weights
+            self.load_lora_weights(lora_ref)
 
-
-            self.
+            # keep metadata for displayed messages
+            self.lora_refs[lora_ref.lora_id] = lora_ref
         except Exception as e:
-
-
-
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
             )
 
-
-        self.update_state_from_configs()
+        return self.create_lora_update_result(success=True)
 
-
-
-
-
+    def validate_new_adapter(self, lora_config: LoRAConfig, lora_ref: LoRARef):
+        """
+        Validate if an adapter can be loaded into the current LoRA memory pool and generate error if it is incompatible.
+        """
 
-
+        memory_pool = getattr(self, "memory_pool", None)
+        incompatible = memory_pool and not memory_pool.can_support(lora_config)
+        if incompatible:
+            raise ValueError(
+                f"LoRA adapter {lora_ref.lora_name} with rank {lora_config.r} is incompatible with the current LoRA memory pool configuration. "
+                "Please ensure that the LoRA adapter's rank is within the configured `--max_lora_rank` and that the target modules are "
+                "included in `--enable_lora_modules`."
+            )
+
+    def unload_lora_adapter(self, lora_ref: LoRARef) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
         """
 
-
-
-
-
-        else:
-            error_message = f"LoRA adapter {lora_name} is not loaded."
-            success = False
+        adapter = self.configs.get(lora_ref.lora_id, None)
+        assert (
+            adapter is not None
+        ), f"LoRA adapter with ID {lora_ref.lora_id} is not loaded. This should have been verified before request is sent to the backend."
 
-
+        try:
+            del self.configs[lora_ref.lora_id]
+            del self.loras[lora_ref.lora_id]
+            del self.lora_refs[lora_ref.lora_id]
+        except Exception as e:
+            return self.create_lora_update_result(
+                success=False,
+                error_message=str(e),
+            )
 
-        return self.create_lora_update_result(
-            success=success,
-            error_message=error_message,
-        )
+        return self.create_lora_update_result(success=True)
 
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
-        #
+        # Load active loras into lora memory pool
+        # TODO (lifuhuang): The naming of `forward_batch.lora_paths` is confusing. It actually contains a set of unique
+        # LoRA IDs, not LoRA paths. While unfortunately we cannot change the name in API for backward compatibility, we
+        # should consider (1) renaming the incorrect usage within the system, and (2) deprecating the parameter name in
+        # the current API schema and introducing a better request schema in the future (e.g., use `model_name`).
         cur_uids = set(forward_batch.lora_paths)
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras, self.lora_modules)
@@ -210,11 +214,11 @@ class LoRAManager:
         weight_indices = [0] * len(forward_batch.lora_paths)
         lora_ranks = [0] * self.max_loras_per_batch
         scalings = [0] * self.max_loras_per_batch
-        for i,
-            weight_indices[i] = self.memory_pool.get_buffer_id(
-            if
-                lora = self.loras[
-                lora_ranks[weight_indices[i]] = lora.config.
+        for i, uid in enumerate(forward_batch.lora_paths):
+            weight_indices[i] = self.memory_pool.get_buffer_id(uid)
+            if uid is not None:
+                lora = self.loras[uid]
+                lora_ranks[weight_indices[i]] = lora.config.r
                 scalings[weight_indices[i]] = lora.scaling
 
         # Use pinned memory to avoid synchronizations during host-to-device transfer
@@ -303,7 +307,7 @@ class LoRAManager:
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
-        for layer_id, layer_modules in self.lora_modules
+        for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
                 if "qkv_proj" in module_name:
                     module.set_lora_info(
@@ -319,7 +323,7 @@ class LoRAManager:
                     )
                 else:
                     weight_name = get_weight_name(
-                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
                     )
                     module.set_lora_info(
                         self.memory_pool.get_tensor(
@@ -330,125 +334,115 @@ class LoRAManager:
                         ),
                     )
 
-    def init_state(
+    def init_state(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+        lora_paths: Optional[Dict[str, LoRARef]] = None,
+    ):
         """
         Initialize the internal (mutable) state of the LoRAManager.
 
-
+        When `lora_paths` is provided and not empty, it might be used for inferring LoRA shape info such as
+        the target modules and max_lora_rank.
         """
 
-
+        assert lora_paths or (
+            max_lora_rank is not None and target_modules is not None
+        ), "When no initial --lora-paths is provided, you need to specify both --max-lora-rank and --lora-target-modules for LoRA initialization."
+
+        self.init_lora_adapters(lora_paths)
+        self.init_lora_shapes(
+            max_lora_rank=max_lora_rank,
+            target_modules=target_modules,
+        )
+        self.init_lora_weight_names()
+        self.init_lora_modules()
+        self.init_memory_pool()
+
+    def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        # Configs of all active LoRA adapters, indexed by LoRA ID.
         self.configs: Dict[str, LoRAConfig] = {}
 
-        # LoRA adapter weights cached in CPU memory.
+        # LoRA adapter weights cached in CPU memory, indexed by LoRA ID.
         self.loras: Dict[str, LoRAAdapter] = {}
 
-        #
-        self.
+        # Mapping from LoRA ID to LoRARef object.
+        self.lora_refs: Dict[str, LoRARef] = {}
 
-
-
-
-
+        if lora_paths:
+            for lora_ref in lora_paths.values():
+                result = self.load_lora_adapter(lora_ref)
+                if not result.success:
+                    raise RuntimeError(
+                        f"Failed to load LoRA adapter {lora_ref.lora_name}: {result.error_message}"
+                    )
 
-
-        self
-
-
-
-
-            self.tp_rank,
-        )
+    def init_lora_shapes(
+        self,
+        max_lora_rank: Optional[int] = None,
+        target_modules: Optional[Iterable[str]] = None,
+    ):
+        """Infer LoRA target modules and max_lora_rank from loaded adapters if not provided."""
 
-
-
-
-
-
-
-        - Initializing LoRA adapters if they are not already loaded.
-        - Collect all LoRA weight names based on the current loaded adapters.
-        - Lazily monkey-patching the base model to use LoRA layers where applicable.
-        - Preparing the GPU buffer pool for active LoRA weights.
-        """
+        if target_modules is not None:
+            self.target_modules = set(target_modules)
+        else:
+            self.target_modules = set()
+            for config in self.configs.values():
+                self.target_modules.update(config.target_modules)
 
-
-
-
-
-
-
-
-
-
-
-        # Lazily update states for new LoRA weight name (e.g., qkv_proj) as needed.
-        #
-        # Please note that the following update operations are "monotonic" by design, meaning that we update
-        # multiple places to support the new weight names when the first adapter targeting such weight names
-        # is loaded. However, we never "rollback" the support (e.g., convert LoRA layer back to base layer)
-        # even if the associated adapters are unloaded later for both simplicity and practicality reasons: the
-        # list of LoRA weight names is expected to be extremely finite and stable.
-        self.update_lora_weight_names(hf_target_module_names)
-        self.update_lora_modules(hf_target_module_names)
-        self.update_memory_buffers(max_lora_dim)
-
-    def update_lora_weight_names(self, hf_target_names: Set[str]):
+        if max_lora_rank is not None:
+            self.max_lora_rank = max_lora_rank
+        else:
+            self.max_lora_rank = max(
+                [x.hf_config["r"] for x in self.configs.values()],
+                default=0,
+            )
+
+    def init_lora_weight_names(self):
         """
         Add new LoRA weight names if needed based on the current `self.configs`.
         """
 
         # Target lora weight names for lora_a and lora_b modules respectively.
-
-
-        self.lora_weight_names[0].update(lora_A)
-        self.lora_weight_names[1].update(lora_B)
+        lora_A, lora_B = get_normalized_lora_weight_names(self.target_modules)
+        self.lora_weight_names: Tuple[Set[str]] = (set(lora_A), set(lora_B))
 
-    def
+    def load_lora_weights(self, lora_ref: LoRARef):
         """
-
-        It loads any new adapters that are not already loaded, and unloads any adapters
-        that are no longer in `self.configs` (e.g., unloaded).
+        Load the weights of a LoRA adapter to CPU memory and conducts post-loading validation.
         """
-
-
-
-
-
-
-
-
-
-            self.load_config,
-            self.lora_backend,
-        )
-        lora_adapter.initialize_weights()
-        self.loras[name] = lora_adapter
-
-        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
-        for name in list(self.loras):
-            if name not in self.configs:
-                logger.info(f"Unloading LoRA adapter {name}")
-                del self.loras[name]
+        lora_adapter = LoRAAdapter(
+            lora_ref.lora_id,
+            self.configs[lora_ref.lora_id],
+            self.base_hf_config,
+            self.load_config,
+            self.lora_backend,
+        )
+        lora_adapter.initialize_weights()
+        self.loras[lora_ref.lora_id] = lora_adapter
 
         # Additional checks for flashinfer backend
         # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
         if self.lora_backend == "flashinfer":
-            lora_dims = set(x.
+            lora_dims = set(x.r for x in self.configs.values())
             scalings = set(x.scaling for x in self.loras.values())
             assert (
                 len(lora_dims) == 1 and len(scalings) == 1
            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
 
-    def
-        """
-
-
-
-
-
-
-        self.
+    def init_memory_pool(self):
+        """(Re)initialize the LoRA memory pool based on the current configurations."""
+        self.memory_pool = LoRAMemoryPool(
+            base_hf_config=self.base_hf_config,
+            max_loras_per_batch=self.max_loras_per_batch,
+            dtype=self.dtype,
+            tp_size=self.tp_size,
+            tp_rank=self.tp_rank,
+            max_lora_rank=self.max_lora_rank,
+            lora_weight_names=self.lora_weight_names,
+            base_model=self.base_model,
         )
 
     def set_lora_module(self, module_name, module):
@@ -456,11 +450,16 @@ class LoRAManager:
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
-    def
+    def init_lora_modules(self):
+        # Look-up table that essentially maps (layer_index, module_name) to the corresponding LoRA module.
+        self.lora_modules: List[Dict[str, BaseLayerWithLoRA]] = [
+            {} for _ in range(self.base_hf_config.num_hidden_layers)
+        ]
+
         # Target module names of customized layers defined in python/sglang/srt/layers
         # e.g., {"qkv_proj", "o_proj"}
         customized_target_names = get_customized_names_from_hf_names(
-
+            self.target_modules, self.base_model
        )
 
         for module_name, module in self.base_model.named_modules():
@@ -477,7 +476,6 @@ class LoRAManager:
             # The module should be converted if it is included in target_names
             if module_name.split(".")[-1] in customized_target_names:
                 layer_id = get_layer_id(module_name)
-
-
-
-                )
+                self.lora_modules[layer_id][module_name] = self.set_lora_module(
+                    module_name, module
+                )
sglang/srt/lora/lora_registry.py
ADDED
@@ -0,0 +1,124 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+
+import asyncio
+from dataclasses import dataclass, field, fields
+from typing import Dict, List, Optional, Union
+from uuid import uuid4
+
+
+@dataclass(frozen=True, slots=True)
+class LoRARef:
+    """
+    Reference record for a LoRA model.
+
+    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
+    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    keys (e.g., radix cache).
+    """
+
+    lora_id: str = field(default_factory=lambda: uuid4().hex)
+    lora_name: Optional[str] = None
+    lora_path: Optional[str] = None
+
+    def __post_init__(self):
+        if self.lora_id is None:
+            raise ValueError("lora_id cannot be None")
+
+    def __str__(self) -> str:
+        parts = [
+            f"{f.name}={value}"
+            for f in fields(self)
+            if (value := getattr(self, f.name)) is not None
+        ]
+        return f"{self.__class__.__name__}({', '.join(parts)})"
+
+
+class LoRARegistry:
+    """
+    The central registry to keep track of available LoRA adapters.
+
+    TODO (lifuhuang): This registry is intended as the foundation for overlapped lora update. We decided
+    to keep it in a separate PR to keep code review simple and to unblock the radix cache work.
+    """
+
+    def __init__(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
+        assert lora_paths is None or all(
+            isinstance(lora, LoRARef) for lora in lora_paths.values()
+        ), (
+            "server_args.lora_paths should have been normalized to LoRARef objects during server initialization. "
+            "Please file an issue if you see this error."
+        )
+
+        # A dictionary to hold LoRARef objects, mapping from LoRA name to LoRARef.
+        self._registry: Dict[str, LoRARef] = dict(lora_paths or {})
+
+    async def register(self, lora_ref: LoRARef):
+        """
+        Register a new LoRARef object in the registry.
+
+        Args:
+            lora_ref (LoRARef): The LoRARef object to register.
+        """
+        if lora_ref.lora_name in self._registry:
+            raise ValueError(
+                f"LoRA with name {lora_ref.lora_name} already exists. Loaded LoRAs: {self._registry.keys()}"
+            )
+        self._registry[lora_ref.lora_name] = lora_ref
+
+    async def unregister(self, lora_name: str) -> str:
+        """
+        Unregister a LoRARef object from the registry and returns the removed LoRA ID.
+
+        Args:
+            lora_name (str): The name of the LoRA model to unregister.
+        """
+        lora_ref = self._registry.get(lora_name, None)
+        if lora_ref is None:
+            raise ValueError(
+                f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
+            )
+        del self._registry[lora_name]
+
+        return lora_ref.lora_id
+
+    async def acquire(self, lora_name: Union[str, List[str]]) -> Union[str, List[str]]:
+        """
+        Queries registry for LoRA IDs based on LoRA names and start tracking the usage of the corresponding LoRA adapters
+        by incrementing its counter.
+
+        TODO (lifuhuang): currently it only queries the registry and does not track the usage of LoRA adapters.
+        """
+
+        async def _acquire_single(name: str) -> str:
+            lora_ref = self._registry.get(name, None)
+            if lora_ref is None:
+                raise ValueError(
+                    f"The following requested LoRA adapters are not loaded: {name}\n"
+                    f"Loaded adapters: {self._registry.keys()}."
+                )
+            # await self._counters[lora_ref.lora_id].increment()
+            return lora_ref.lora_id
+
+        if isinstance(lora_name, str):
+            lora_id = await _acquire_single(lora_name)
+            return lora_id
+        elif isinstance(lora_name, list):
+            lora_ids = await asyncio.gather(
+                *[_acquire_single(name) for name in lora_name]
+            )
+            return lora_ids
+        else:
+            raise TypeError("lora_name must be either a string or a list of strings.")