sglang 0.5.0rc0__py3-none-any.whl → 0.5.0rc2__py3-none-any.whl
This diff compares the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +6 -1
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +8 -7
- sglang/srt/disaggregation/decode.py +8 -4
- sglang/srt/disaggregation/mooncake/conn.py +43 -25
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/distributed/parallel_state.py +4 -2
- sglang/srt/entrypoints/context.py +3 -20
- sglang/srt/entrypoints/engine.py +13 -8
- sglang/srt/entrypoints/harmony_utils.py +2 -0
- sglang/srt/entrypoints/http_server.py +68 -5
- sglang/srt/entrypoints/openai/protocol.py +2 -9
- sglang/srt/entrypoints/openai/serving_chat.py +60 -265
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/tool_server.py +4 -3
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/jinja_template_utils.py +6 -0
- sglang/srt/layers/attention/aiter_backend.py +370 -107
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +55 -13
- sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -0
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +24 -27
- sglang/srt/layers/attention/trtllm_mha_backend.py +8 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +129 -25
- sglang/srt/layers/attention/vision.py +9 -1
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +11 -13
- sglang/srt/layers/dp_attention.py +118 -27
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +12 -18
- sglang/srt/layers/moe/cutlass_moe.py +11 -16
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +60 -2
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -9
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +4 -1
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +10 -35
- sglang/srt/layers/quantization/awq.py +15 -16
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +0 -1
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +22 -10
- sglang/srt/layers/quantization/gptq.py +12 -17
- sglang/srt/layers/quantization/marlin_utils.py +15 -5
- sglang/srt/layers/quantization/modelopt_quant.py +58 -41
- sglang/srt/layers/quantization/mxfp4.py +20 -3
- sglang/srt/layers/quantization/utils.py +52 -2
- sglang/srt/layers/quantization/w4afp8.py +20 -11
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +281 -2
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +66 -116
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +12 -48
- sglang/srt/lora/lora_registry.py +20 -9
- sglang/srt/lora/mem_pool.py +20 -63
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +24 -29
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -6
- sglang/srt/managers/mm_utils.py +1 -2
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +43 -49
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +18 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/tokenizer_manager.py +53 -44
- sglang/srt/mem_cache/allocator.py +39 -214
- sglang/srt/mem_cache/allocator_ascend.py +158 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +34 -24
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +33 -35
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -23
- sglang/srt/model_executor/forward_batch_info.py +33 -14
- sglang/srt/model_executor/model_runner.py +179 -81
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/models/deepseek_nextn.py +2 -1
- sglang/srt/models/deepseek_v2.py +79 -38
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +8 -9
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +11 -11
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +142 -20
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +10 -27
- sglang/srt/models/llama4.py +19 -6
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +20 -5
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_classification.py +78 -0
- sglang/srt/models/qwen3_moe.py +18 -5
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +6 -2
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/operations.py +17 -2
- sglang/srt/reasoning_parser.py +316 -0
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +142 -140
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +16 -12
- sglang/srt/utils.py +3 -3
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/METADATA +27 -31
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/RECORD +166 -142
- sglang/lang/backend/__init__.py +0 -0
- sglang/srt/function_call/harmony_tool_parser.py +0 -130
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc0.dist-info → sglang-0.5.0rc2.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora.py
CHANGED

@@ -117,7 +117,6 @@ class LoRAAdapter(nn.Module):
             q_name = weight_name
             k_name = weight_name.replace("q_proj", "k_proj")
             v_name = weight_name.replace("q_proj", "v_proj")
-            kv_name = weight_name.replace("q_proj", "kv_proj")
             qkv_name = weight_name.replace("q_proj", "qkv_proj")
 
             # If k_proj doesn't have lora, initialize it to zero
@@ -126,57 +125,27 @@ class LoRAAdapter(nn.Module):
                 if "k_proj" in target_module
                 else torch.zeros_like(weights[v_name])
             )
-            [... 12 removed lines not captured in this diff view ...]
-                weights.pop(v_name)
-            else:
-                weights[kv_name] = torch.stack(
-                    [
-                        k_proj_weight,
-                        weights[v_name],
-                    ],
-                    dim=0,
-                )
-                if "k_proj" in target_module:
-                    weights.pop(k_name)
-                weights.pop(v_name)
+            weights[qkv_name] = torch.cat(
+                (
+                    weights[q_name],
+                    k_proj_weight,
+                    weights[v_name],
+                ),
+                0,
+            )
+            weights.pop(q_name)
+            if "k_proj" in target_module:
+                weights.pop(k_name)
+            weights.pop(v_name)
         elif "qkv_proj" in weight_name:
             # If qkv_proj is already stacked, we normalize it following the SGL convention.
             qkv_name = weight_name
             q_name = weight_name.replace("qkv_proj", "q_proj")
             k_name = weight_name.replace("qkv_proj", "k_proj")
             v_name = weight_name.replace("qkv_proj", "v_proj")
-            kv_name = weight_name.replace("qkv_proj", "kv_proj")
             if "lora_A" in weight_name:
                 weights[qkv_name] = weights[qkv_name].repeat(3, 1)
-            else:
-                head_size = (
-                    self.base_hf_config.hidden_size
-                    // self.base_hf_config.num_attention_heads
-                )
-                weights[q_name], k_proj_weight, v_proj_weight = torch.split(
-                    weights[qkv_name],
-                    [
-                        head_size * self.base_hf_config.num_attention_heads,
-                        head_size * self.base_hf_config.num_key_value_heads,
-                        head_size * self.base_hf_config.num_key_value_heads,
-                    ],
-                    dim=0,
-                )
-                weights[kv_name] = torch.stack(
-                    [k_proj_weight, v_proj_weight],
-                    dim=0,
-                )
+            # else: no-op as LoRA B weight is already stacked.
 
     def normalize_gate_up_proj(
         self, weight_names: List[str], weights: Dict[str, torch.Tensor]
@@ -187,20 +156,14 @@ class LoRAAdapter(nn.Module):
             gate_up_name = weight_name.replace("gate_proj", "gate_up_proj")
             if up_name not in weights:
                 weights[up_name] = torch.zeros_like(weights[weight_name])
-                # FIXME: Add gate-only support for flashinfer in future implementations
                 assert self.lora_backend.name == "triton", (
                     f"LoRA weight initialization currently only supported for 'triton' backend. "
                     f"Received backend: {self.lora_backend.name}. Please verify your backend configuration "
                     f"or consider implementing custom initialization logic for other backends."
                 )
-            [... 4 removed lines only partially captured in this diff view ...]
-            else:
-                weights[gate_up_name] = torch.stack(
-                    [weights[weight_name], weights[up_name]], dim=0
-                )
+            weights[gate_up_name] = torch.cat(
+                (weights[weight_name], weights[up_name]), 0
+            )
             weights.pop(weight_name)
             if up_name in weights:
                 weights.pop(up_name)
@@ -209,12 +172,4 @@ class LoRAAdapter(nn.Module):
             gate_up_name = weight_name
             if "lora_A" in weight_name:
                 weights[gate_up_name] = weights[gate_up_name].repeat(2, 1)
-            else:
-                output_dim = weights[gate_up_name].shape[0] // 2
-                weights[gate_up_name] = torch.stack(
-                    [
-                        weights[gate_up_name][:output_dim, :],
-                        weights[gate_up_name][output_dim:, :],
-                    ],
-                    dim=0,
-                )
+            # else: no-op as LoRA B weight is already stacked.
sglang/srt/lora/lora_manager.py
CHANGED

@@ -31,7 +31,6 @@ from sglang.srt.lora.mem_pool import LoRAMemoryPool
 from sglang.srt.lora.utils import (
     LoRABatchInfo,
     LoRAType,
-    get_customized_names_from_hf_names,
     get_layer_id,
     get_normalized_lora_weight_names,
     get_weight_name,
@@ -345,40 +344,19 @@ class LoRAManager:
         )
         self.lora_backend.set_batch_info(batch_info)
 
-        # TODO (lifuhuang): one potential perf optimization that is worth considering is to see if we can call
-        # this method only when loading/unloading LoRA adapters, instead of calling it for every micro-batch.
-        self.update_lora_info()
-
     def update_lora_info(self):
         """
         Update all LoRA modules to associate them with the latest memory buffer.
         """
         for layer_id, layer_modules in enumerate(self.lora_modules):
             for module_name, module in layer_modules.items():
-                [... 7 removed lines not captured in this diff view ...]
-                        ),
-                        self.memory_pool.get_tensor(
-                            "kv_proj", layer_id, LoRAType.LORA_B
-                        ),
-                    )
-                else:
-                    weight_name = get_weight_name(
-                        module_name, self.memory_pool.lora_weight_names, LoRAType.LORA_A
-                    )
-                    module.set_lora_info(
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_A
-                        ),
-                        self.memory_pool.get_tensor(
-                            weight_name, layer_id, LoRAType.LORA_B
-                        ),
-                    )
+                weight_name = get_weight_name(
+                    module_name, self.memory_pool.lora_weight_names
+                )
+                module.set_lora_info(
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_A),
+                    self.memory_pool.get_tensor(weight_name, layer_id, LoRAType.LORA_B),
+                )
 
     def init_state(
         self,
@@ -405,6 +383,7 @@ class LoRAManager:
         self.init_lora_weight_names()
         self.init_lora_modules()
         self.init_memory_pool()
+        self.update_lora_info()
 
     def init_lora_adapters(self, lora_paths: Optional[Dict[str, LoRARef]] = None):
         # Configs of all active LoRA adapters, indexed by LoRA ID.
@@ -461,9 +440,9 @@ class LoRAManager:
         Add new LoRA weight names if needed based on the current `self.configs`.
         """
 
-        [... 3 removed lines not captured in this diff view ...]
+        self.lora_weight_names: Set[str] = get_normalized_lora_weight_names(
+            self.target_modules
+        )
 
     def load_lora_weights(self, lora_ref: LoRARef):
         """
@@ -479,15 +458,6 @@ class LoRAManager:
         lora_adapter.initialize_weights()
         self.loras[lora_ref.lora_id] = lora_adapter
 
-        # Additional checks for flashinfer backend
-        # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
-        if self.lora_backend == "flashinfer":
-            lora_dims = set(x.r for x in self.configs.values())
-            scalings = set(x.scaling for x in self.loras.values())
-            assert (
-                len(lora_dims) == 1 and len(scalings) == 1
-            ), "Flashinfer backend currently only supports single LoRA rank and scaling across all adapters. "
-
     def init_memory_pool(self):
         """(Re)initialize the LoRA memory pool based on the current configurations."""
         self.memory_pool = LoRAMemoryPool(
@@ -512,12 +482,6 @@ class LoRAManager:
             {} for _ in range(self.base_hf_config.num_hidden_layers)
         ]
 
-        # Target module names of customized layers defined in python/sglang/srt/layers
-        # e.g., {"qkv_proj", "o_proj"}
-        customized_target_names = get_customized_names_from_hf_names(
-            self.target_modules, self.base_model
-        )
-
         for module_name, module in self.base_model.named_modules():
             # TODO (lifuhuang): in the future, we should consider generalizing the
             # should_apply_lora function to support mapping by full module name instead
@@ -530,7 +494,7 @@ class LoRAManager:
                 continue
 
             # The module should be converted if it is included in target_names
-            if module_name.split(".")[-1] in
+            if module_name.split(".")[-1] in self.lora_weight_names:
                 layer_id = get_layer_id(module_name)
                 self.lora_modules[layer_id][module_name] = self.set_lora_module(
                     module_name, module
sglang/srt/lora/lora_registry.py
CHANGED

@@ -14,7 +14,6 @@
 
 
 import asyncio
-from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
@@ -106,7 +105,6 @@ class LoRARegistry:
                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
             )
         del self._registry[lora_name]
-        del self._counters[lora_ref.lora_id]
 
         return lora_ref.lora_id
 
@@ -117,6 +115,9 @@ class LoRARegistry:
         """
 
         def _lookup(name: str) -> str:
+            if name is None:
+                return None
+
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
@@ -135,7 +136,11 @@ class LoRARegistry:
 
             # Increment the counters only after all IDs are looked up.
             await asyncio.gather(
-                *[
+                *[
+                    self._counters[id].increment(notify_all=False)
+                    for id in lora_ids
+                    if id is not None
+                ]
             )
             return lora_ids
         else:
@@ -153,7 +158,11 @@ class LoRARegistry:
            await self._counters[lora_id].decrement()
        elif isinstance(lora_id, list):
            await asyncio.gather(
-                *[
+                *[
+                    self._counters[id].decrement()
+                    for id in lora_id
+                    if id is not None
+                ]
            )
        else:
            raise TypeError("lora_id must be either a string or a list of strings.")
@@ -169,11 +178,13 @@ class LoRARegistry:
         assert (
             lora_id not in self._registry
         ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
-        [... 5 removed lines not captured in this diff view ...]
+        assert (
+            lora_id in self._counters
+        ), "The LoRA ID should still have a counter if it has been registered before."
+
+        # Wait until no requests are using this LoRA adapter.
+        await self._counters[lora_id].wait_for_zero()
+        del self._counters[lora_id]
 
     def _register_adapter(self, lora_ref: LoRARef):
         """
sglang/srt/lora/mem_pool.py
CHANGED

@@ -52,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names:
+        lora_weight_names: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -62,9 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-
-        # lora weight names for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
+        self.lora_weight_names: Set[str] = lora_weight_names
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -97,12 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            [... 2 removed lines not captured in this diff view ...]
-            )
-            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
-                self.lora_weight_names[1]
-            )
+            weights = get_normalized_lora_weight_names(config.target_modules)
+            return weights.issubset(self.lora_weight_names)
 
         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -132,11 +126,9 @@ class LoRAMemoryPool:
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
-        c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
-            c,
             self.max_loras_per_batch,
             output_dim,
             max_lora_dim,
@@ -165,13 +157,13 @@ class LoRAMemoryPool:
 
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names
+            self.lora_weight_names,
             self.get_lora_A_shape,
         )
 
         init_buffer(
             self.B_buffer,
-            self.lora_weight_names
+            self.lora_weight_names,
             self.get_lora_B_shape,
         )
 
@@ -246,7 +238,7 @@ class LoRAMemoryPool:
             return
 
         assert lora_adapter is not None
-        lora_rank = lora_adapter.config.
+        lora_rank = lora_adapter.config.r
         for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -256,73 +248,38 @@ class LoRAMemoryPool:
                 weight_name: None for weight_name in self.B_buffer
             }
             for name, weights in layer_weights.items():
+                lora_weight_name = get_weight_name(name, self.lora_weight_names)
                 if "lora_A" in name:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_A
-                    )
                     temp_A_buffer[lora_weight_name] = weights
                 else:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_B
-                    )
                     temp_B_buffer[lora_weight_name] = weights
 
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(
-                        module_name, self.lora_weight_names, LoRAType.LORA_A
-                    )
+                    weight_name = get_weight_name(module_name, self.lora_weight_names)
 
                     if temp_A_buffer[weight_name] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue
 
-                    [... 6 removed lines only partially captured in this diff view ...]
-                            [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
-                            self.tp_rank,
-                        )
-                    )
-                    else:
-                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
-                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
-                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
-                        # FlashInfer LoRA backend.
-                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                            temp_A_buffer[weight_name], self.tp_rank
-                        )
-                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                            temp_B_buffer[weight_name], self.tp_rank
-                        )
+                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                        temp_A_buffer[weight_name], self.tp_rank
+                    )
+                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                        temp_B_buffer[weight_name], self.tp_rank
+                    )
 
             for name, weights in temp_A_buffer.items():
                 c = get_stacked_multiply(name)
-                [... 3 removed lines only partially captured in this diff view ...]
+                target_buffer = self.A_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
                 load_lora_weight_tensor(buffer_view, weights)
 
             for name, weights in temp_B_buffer.items():
-                [... 3 removed lines not captured in this diff view ...]
-                    buffer_view = self.B_buffer[name][layer_id][stacked_id][
-                        buffer_id
-                    ][:, :lora_rank]
-                    weight_slice = (
-                        weights[stacked_id] if weights is not None else None
-                    )
-                    load_lora_weight_tensor(buffer_view, weight_slice)
-                else:
-                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
-                        :, :lora_rank
-                    ]
-                    load_lora_weight_tensor(buffer_view, weights)
+                target_buffer = self.B_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, :, :lora_rank]
+                load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType

sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED

@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
     partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
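The one-line kernel fix above adds a column bound to the output mask: without `n_offset < n_size`, the last tile of a row would read and write past the valid output width whenever the width is not a multiple of the tile size. A tiny torch analogue of the mask arithmetic (the block sizes are made up for illustration; this is not the Triton kernel):

```python
import torch

n_size, n_start, block_n = 10, 8, 8          # last tile covers columns 8..15
n_offset = n_start + torch.arange(block_n)   # [8, 9, ..., 15]

seg_len, block_s = 4, 6
s_offset = torch.arange(block_s)             # rows 0..5, only rows 0..3 are valid

rows_ok = s_offset[:, None] < seg_len        # (block_s, 1)
cols_ok = n_offset[None, :] < n_size         # (1, block_n)

mask_old = rows_ok.expand(block_s, block_n)  # old mask: rows only
mask_new = rows_ok & cols_ok                 # new mask: rows AND columns

print(int(mask_old.sum()))  # 32 -> includes columns 10..15, which are out of range
print(int(mask_new.sum()))  # 8  -> only the in-range (row, column) pairs
```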
sglang/srt/lora/utils.py
CHANGED

@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
     return int(match.group(1))
 
 
-def get_customized_names_from_hf_names(
-    hf_module_names: Set[str], base_model: torch.nn.Module
-) -> Set[str]:
-    """
-    This function takes in a set of huggingface style module names:
-        e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-    and outputs a set of module names of customized sglang layers:
-        e.g., {"qkv_proj", "o_proj"}
-    """
-    if hasattr(base_model, "get_module_name"):
-        return {base_model.get_module_name(name) for name in hf_module_names}
-    else:
-        """
-        Fallback solution of mapping from config module name to module name in model class.
-        Please check if it aligns with your base model.
-        Please implement the function in the model class if it is not.
-        You can reference this function in llama.py.
-        """
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return {params_mapping.get(name, name) for name in hf_module_names}
-
-
 def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
@@ -92,14 +64,20 @@ def get_hidden_dim(
     Please implement the function in the model class if it is not.
     You can reference this function in llama.py.
     """
-    [... 5 removed lines not captured in this diff view ...]
+    head_dim = getattr(
+        config, "head_dim", config.hidden_size // config.num_attention_heads
+    )
+    if module_name == "qkv_proj":
+        return config.hidden_size, head_dim * (
+            config.num_attention_heads + config.num_key_value_heads * 2
+        )
+    elif module_name == "o_proj":
+        return (
+            head_dim * config.num_attention_heads,
+            config.hidden_size,
         )
     elif module_name == "gate_up_proj":
-        return config.hidden_size, config.intermediate_size
+        return config.hidden_size, config.intermediate_size * 2
     elif module_name == "down_proj":
         return config.intermediate_size, config.hidden_size
     else:
@@ -108,26 +86,22 @@ def get_hidden_dim(
 
 def get_normalized_lora_weight_names(
     target_modules: Iterable[str],
-) ->
+) -> set[str]:
     """
     Mapping a list of target module name to names of the normalized LoRA weights.
-    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj":
-        "k_proj":
-        "v_proj":
-        "gate_proj":
-        "up_proj":
-        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "q_proj": "qkv_proj",
+        "k_proj": "qkv_proj",
+        "v_proj": "qkv_proj",
+        "gate_proj": "gate_up_proj",
+        "up_proj": "gate_up_proj",
     }
 
-    result =
+    result = set()
     for name in target_modules:
-        [... 2 removed lines not captured in this diff view ...]
-        result[1].update(lora_b)
+        weight_name = params_mapping.get(name, name)
+        result.add(weight_name)
     return result
 
 
@@ -137,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
     """
     stacked_rank = {
         "qkv_proj": 3,
-        "kv_proj": 2,
         "gate_up_proj": 2,
     }
     return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
 def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]]
+    target_name: str, lora_weight_names: Tuple[Set[str]]
 ) -> Optional[str]:
     """
-    [... 2 removed lines not captured in this diff view ...]
+    Get the weight name in lora_weight_names that can match target_name.
+
     If there is a weight name in lora_weight_names that can match target_name, return this name
     Else raise ValueError.
     """
-    [... 1 removed line not captured in this diff view ...]
-    for weight_name in lora_weight_names[idx]:
+    for weight_name in lora_weight_names:
         if weight_name in target_name:
             return weight_name
     raise ValueError(
@@ -161,9 +133,4 @@ def get_weight_name(
     )
 
 
-# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
-VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
-COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
-MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
-QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
 ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]