sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
sglang/srt/lora/lora_registry.py
CHANGED
@@ -14,7 +14,6 @@
 
 
 import asyncio
-from collections import defaultdict
 from dataclasses import dataclass, field, fields
 from typing import Dict, List, Optional, Union
 from uuid import uuid4
@@ -28,14 +27,15 @@ class LoRARef:
     """
     Reference record for a LoRA model.
 
-    This object guarantees a unique ``lora_id`` and may include ``lora_name`` and ``lora_path``. The ID
-    eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
+    This object guarantees a unique ``lora_id`` and may include ``lora_name``, ``lora_path``, and ``pinned``.
+    The ID eliminates conflicts from reused LoRA names or paths and can be used to generate deterministic cache
     keys (e.g., radix cache).
     """
 
     lora_id: str = field(default_factory=lambda: uuid4().hex)
     lora_name: Optional[str] = None
     lora_path: Optional[str] = None
+    pinned: Optional[bool] = None
 
     def __post_init__(self):
         if self.lora_id is None:
@@ -105,7 +105,6 @@ class LoRARegistry:
                 f"LoRA with name {lora_name} does not exist. Loaded LoRAs: {self._registry.keys()}"
             )
         del self._registry[lora_name]
-        del self._counters[lora_ref.lora_id]
 
         return lora_ref.lora_id
 
@@ -116,6 +115,9 @@ class LoRARegistry:
         """
 
         def _lookup(name: str) -> str:
+            if name is None:
+                return None
+
             lora_ref = self._registry.get(name, None)
             if lora_ref is None:
                 raise ValueError(
@@ -134,7 +136,11 @@ class LoRARegistry:
 
             # Increment the counters only after all IDs are looked up.
             await asyncio.gather(
-                *[self._counters[id].increment(notify_all=False) for id in lora_ids]
+                *[
+                    self._counters[id].increment(notify_all=False)
+                    for id in lora_ids
+                    if id is not None
+                ]
             )
             return lora_ids
         else:
@@ -152,7 +158,11 @@ class LoRARegistry:
             await self._counters[lora_id].decrement()
         elif isinstance(lora_id, list):
             await asyncio.gather(
-                *[self._counters[id].decrement() for id in lora_id]
+                *[
+                    self._counters[id].decrement()
+                    for id in lora_id
+                    if id is not None
+                ]
             )
         else:
             raise TypeError("lora_id must be either a string or a list of strings.")
@@ -168,11 +178,13 @@ class LoRARegistry:
         assert (
             lora_id not in self._registry
         ), "wait_for_unload should only be called after the LoRA adapter has been unregistered. "
-
-
-
-
-
+        assert (
+            lora_id in self._counters
+        ), "The LoRA ID should still have a counter if it has been registered before."
+
+        # Wait until no requests are using this LoRA adapter.
+        await self._counters[lora_id].wait_for_zero()
+        del self._counters[lora_id]
 
     def _register_adapter(self, lora_ref: LoRARef):
         """
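Note: the sketch below illustrates the new `pinned` field and the None-tolerant acquire/release path introduced above. It relies only on the `LoRARef` fields shown in this diff; the adapter name and path are made-up examples.

from sglang.srt.lora.lora_registry import LoRARef

# lora_id is auto-generated via uuid4().hex when not supplied explicitly.
ref = LoRARef(lora_name="my-adapter", lora_path="/path/to/my-adapter", pinned=True)
print(ref.lora_id, ref.lora_name, ref.pinned)

# With the changes above, acquire()/release() skip entries whose looked-up ID is
# None (i.e. requests that run the base model without a LoRA adapter), instead of
# incrementing or decrementing a counter for them.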
sglang/srt/lora/mem_pool.py
CHANGED
@@ -1,3 +1,4 @@
+import logging
 from typing import Callable, Dict, Iterable, List, Optional, Set, Tuple, Union
 
 import torch
@@ -7,6 +8,7 @@ from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.layers import BaseLayerWithLoRA
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
+from sglang.srt.lora.lora_registry import LoRARef
 from sglang.srt.lora.utils import (
     ROW_PARALLELISM_LINEAR_LORA_NAMES,
     LoRAType,
@@ -16,6 +18,28 @@ from sglang.srt.lora.utils import (
     get_weight_name,
 )
 
+logger = logging.getLogger(__name__)
+
+
+class EmptySlot:
+    """
+    Singleton class to represent an empty slot in the memory pool.
+    This is used to improve readability by not using special str as a placeholder.
+    """
+
+    __slots__ = ()
+
+    def __repr__(self):
+        return "|EMPTY|"
+
+    def __new__(cls):
+        if not hasattr(cls, "_instance"):
+            cls._instance = super().__new__(cls)
+        return cls._instance
+
+
+EMPTY_SLOT = EmptySlot()
+
 
 class LoRAMemoryPool:
     """Class for memory pool management of lora modules"""
@@ -28,7 +52,7 @@ class LoRAMemoryPool:
         tp_size: int,
         tp_rank: int,
         max_lora_rank: int,
-        lora_weight_names: Tuple[Set[str], Set[str]],
+        lora_weight_names: Set[str],
         base_model: torch.nn.Module,
     ):
         self.base_hf_config: AutoConfig = base_hf_config
@@ -38,9 +62,7 @@ class LoRAMemoryPool:
         self.tp_size: int = tp_size
         self.tp_rank: int = tp_rank
         self.max_lora_rank: int = max_lora_rank
-
-        # lora weight names for LoRA A and B respectively.
-        self.lora_weight_names: Tuple[Set[str], Set[str]] = lora_weight_names
+        self.lora_weight_names: Set[str] = lora_weight_names
 
         # Both A_buffer and B_buffer maps lora weight names to its buffer space.
         # A_buffer contains num_layer number of row-major tensors with shape
@@ -54,9 +76,11 @@ class LoRAMemoryPool:
         self.uid_to_buffer_id: Dict[Optional[str], int] = {}
 
         # Buffer idx -> lora uid in memory pool
-        # All uids are initialized as ""
+        # All uids are initialized as `EmptySlot` for empty buffer slots
         # Here we don't initialize to None since None is a valid uid
-        self.buffer_id_to_uid: List[Optional[str]] = [""] * self.max_loras_per_batch
+        self.buffer_id_to_uid: List[Union[str, None, EmptySlot]] = [
+            EMPTY_SLOT
+        ] * self.max_loras_per_batch
 
         self.init_buffers(base_model)
 
@@ -71,12 +95,8 @@ class LoRAMemoryPool:
             """
             if config.r > self.max_lora_rank:
                 return False
-            weights_a, weights_b = get_normalized_lora_weight_names(
-                config.target_modules
-            )
-            return weights_a.issubset(self.lora_weight_names[0]) and weights_b.issubset(
-                self.lora_weight_names[1]
-            )
+            weights = get_normalized_lora_weight_names(config.target_modules)
+            return weights.issubset(self.lora_weight_names)
 
         if isinstance(config, LoRAConfig):
             return _can_support(config)
@@ -106,11 +126,9 @@ class LoRAMemoryPool:
         Given a module_name (might be a stacked name), return the hidden dims of modules' input and output.
         """
         _, output_dim = get_hidden_dim(module_name, self.base_hf_config, base_model)
-        c = get_stacked_multiply(module_name)
         if self.tp_size > 1 and module_name not in ROW_PARALLELISM_LINEAR_LORA_NAMES:
             output_dim = divide(output_dim, self.tp_size)
         return (
-            c,
             self.max_loras_per_batch,
             output_dim,
             max_lora_dim,
@@ -139,13 +157,13 @@ class LoRAMemoryPool:
 
         init_buffer(
             self.A_buffer,
-            self.lora_weight_names[0],
+            self.lora_weight_names,
             self.get_lora_A_shape,
         )
 
         init_buffer(
             self.B_buffer,
-            self.lora_weight_names[1],
+            self.lora_weight_names,
             self.get_lora_B_shape,
         )
 
@@ -154,17 +172,29 @@ class LoRAMemoryPool:
         cur_uids: Set[Optional[str]],
         lora_adapters: Dict[str, LoRAAdapter],
         lora_modules: List[Dict[str, BaseLayerWithLoRA]],
+        lora_refs: Dict[str, LoRARef],
     ):
         def get_available_buffer_slot():
             for buffer_id in range(self.max_loras_per_batch):
                 # Prioritize empty slots
-                if self.buffer_id_to_uid[buffer_id] == "":
+                if self.buffer_id_to_uid[buffer_id] == EMPTY_SLOT:
                     return buffer_id
 
             for buffer_id in range(self.max_loras_per_batch):
+                uid = self.buffer_id_to_uid[buffer_id]
+
                 # Evict unneeded lora
-                if self.buffer_id_to_uid[buffer_id] not in cur_uids:
-                    self.uid_to_buffer_id.pop(self.buffer_id_to_uid[buffer_id])
+                if uid not in cur_uids:
+                    # Skip pinned LoRAs
+                    # TODO (lifuhuang): we might consider supporting pinning base model (uid == None) in the future.
+                    if uid is not None:
+                        lora_ref = lora_refs.get(uid)
+                        if lora_ref is not None and lora_ref.pinned:
+                            continue
+
+                    self.uid_to_buffer_id.pop(uid)
+                    logger.debug(f"Evicting LoRA {uid} from buffer slot {buffer_id}.")
+                    self.buffer_id_to_uid[buffer_id] = EMPTY_SLOT
                     return buffer_id
 
             raise ValueError(
@@ -208,7 +238,7 @@ class LoRAMemoryPool:
             return
 
         assert lora_adapter is not None
-        lora_rank = lora_adapter.config.hf_config["r"]
+        lora_rank = lora_adapter.config.r
        for layer_id in range(self.num_layer):
             layer_weights = lora_adapter.layers[layer_id].weights
             temp_A_buffer: Dict[str, Optional[torch.Tensor]] = {
@@ -218,73 +248,38 @@ class LoRAMemoryPool:
                 weight_name: None for weight_name in self.B_buffer
             }
             for name, weights in layer_weights.items():
+                lora_weight_name = get_weight_name(name, self.lora_weight_names)
                 if "lora_A" in name:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_A
-                    )
                     temp_A_buffer[lora_weight_name] = weights
                 else:
-                    lora_weight_name = get_weight_name(
-                        name, self.lora_weight_names, LoRAType.LORA_B
-                    )
                     temp_B_buffer[lora_weight_name] = weights
 
             if self.tp_size > 1:
                 cur_layer_modules = lora_modules[layer_id]
                 for module_name, module in cur_layer_modules.items():
-                    weight_name = get_weight_name(
-                        module_name, self.lora_weight_names, LoRAType.LORA_A
-                    )
+                    weight_name = get_weight_name(module_name, self.lora_weight_names)
 
                     if temp_A_buffer[weight_name] is None:
                         # Skip weight slicing if the weight is not present in the adapter
                         continue
 
-                    if weight_name == "qkv_proj":
-                        temp_A_buffer["qkv_proj"] = module.slice_lora_a_weights(
-                            temp_A_buffer["qkv_proj"], self.tp_rank
-                        )
-                        temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"] = (
-                            module.slice_lora_b_weights(
-                                [temp_B_buffer["q_proj"], temp_B_buffer["kv_proj"]],
-                                self.tp_rank,
-                            )
-                        )
-                    else:
-                        # TODO (lifuhuang): Ideally, we should call `get_weight_name` separately for both A and B.
-                        # Currently, we're reusing A's weight name as a workaround, relying on the fact that A and
-                        # B share the same name except for `qkv_proj`. We should clean this up once we deprecate the
-                        # FlashInfer LoRA backend.
-                        temp_A_buffer[weight_name] = module.slice_lora_a_weights(
-                            temp_A_buffer[weight_name], self.tp_rank
-                        )
-                        temp_B_buffer[weight_name] = module.slice_lora_b_weights(
-                            temp_B_buffer[weight_name], self.tp_rank
-                        )
+                    temp_A_buffer[weight_name] = module.slice_lora_a_weights(
+                        temp_A_buffer[weight_name], self.tp_rank
+                    )
+                    temp_B_buffer[weight_name] = module.slice_lora_b_weights(
+                        temp_B_buffer[weight_name], self.tp_rank
+                    )
 
             for name, weights in temp_A_buffer.items():
                 c = get_stacked_multiply(name)
-                buffer_view = self.A_buffer[name][layer_id][buffer_id][
-                    : lora_rank * c, :
-                ]
+                target_buffer = self.A_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, : lora_rank * c, :]
                 load_lora_weight_tensor(buffer_view, weights)
 
             for name, weights in temp_B_buffer.items():
-                c = get_stacked_multiply(name)
-                if c > 1:
-                    for stacked_id in range(c):
-                        buffer_view = self.B_buffer[name][layer_id][stacked_id][
-                            buffer_id
-                        ][:, :lora_rank]
-                        weight_slice = (
-                            weights[stacked_id] if weights is not None else None
-                        )
-                        load_lora_weight_tensor(buffer_view, weight_slice)
-                else:
-                    buffer_view = self.B_buffer[name][layer_id][0][buffer_id][
-                        :, :lora_rank
-                    ]
-                    load_lora_weight_tensor(buffer_view, weights)
+                target_buffer = self.B_buffer[name][layer_id]
+                buffer_view = target_buffer[buffer_id, :, :lora_rank]
+                load_lora_weight_tensor(buffer_view, weights)
 
     def get_tensor(
         self, weight_name: str, layer_id: int, lora_type: LoRAType
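Note: `EmptySlot` above is a plain singleton sentinel; it is used instead of None because None is itself a valid uid (the base model). A small self-contained sketch of its behavior, mirroring the class added in this diff:

class EmptySlot:
    # Singleton placeholder for an empty buffer slot (mirrors the class in mem_pool.py above).
    __slots__ = ()

    def __repr__(self):
        return "|EMPTY|"

    def __new__(cls):
        # Every construction returns the same shared instance.
        if not hasattr(cls, "_instance"):
            cls._instance = super().__new__(cls)
        return cls._instance


EMPTY_SLOT = EmptySlot()
assert EmptySlot() is EMPTY_SLOT            # singleton: same object every time
assert repr(EMPTY_SLOT) == "|EMPTY|"
slots = [EMPTY_SLOT] * 4                    # how buffer_id_to_uid starts out
assert all(s == EMPTY_SLOT for s in slots)  # empty slots compare equal to the sentinel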
sglang/srt/lora/triton_ops/qkv_lora_b.py
CHANGED
@@ -119,7 +119,7 @@ def _qkv_lora_b_kernel(
     output_ptr = (output + seg_start * output_stride_0 + n_start * output_stride_1) + (
         s_offset[:, None] * output_stride_0 + n_offset[None, :] * output_stride_1
     )
-    output_mask = (s_offset[:, None] < seg_len)
+    output_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)
     partial_sum += tl.load(output_ptr, mask=output_mask)
     tl.store(output_ptr, partial_sum, mask=output_mask)
 
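Note: the one-line fix above bounds the output mask in both the sequence and output-feature dimensions. The NumPy sketch below shows the 2-D mask it computes; the tile sizes and bounds are illustrative values, not taken from the kernel.

import numpy as np

BLOCK_S, BLOCK_N = 4, 8           # illustrative tile sizes
seg_len, n_size = 3, 5            # valid rows / columns within this tile

s_offset = np.arange(BLOCK_S)
n_offset = np.arange(BLOCK_N)

# Old mask bounded only the rows, so columns past n_size were still read/written.
old_mask = s_offset[:, None] < seg_len

# Fixed mask bounds rows and columns, matching the added `& (n_offset[None, :] < n_size)` term.
new_mask = (s_offset[:, None] < seg_len) & (n_offset[None, :] < n_size)

old_valid = np.broadcast_to(old_mask, (BLOCK_S, BLOCK_N)).sum()   # 24: whole rows pass
new_valid = new_mask.sum()                                        # 15: only the 3 x 5 block passes
print(old_valid, new_valid)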
sglang/srt/lora/utils.py
CHANGED
@@ -47,34 +47,6 @@ def get_layer_id(name: str) -> int:
     return int(match.group(1))
 
 
-def get_customized_names_from_hf_names(
-    hf_module_names: Set[str], base_model: torch.nn.Module
-) -> Set[str]:
-    """
-    This function takes in a set of huggingface style module names:
-        e.g., {"k_proj", "q_proj", "v_proj", "o_proj"}
-    and outputs a set of module names of customized sglang layers:
-        e.g., {"qkv_proj", "o_proj"}
-    """
-    if hasattr(base_model, "get_module_name"):
-        return {base_model.get_module_name(name) for name in hf_module_names}
-    else:
-        """
-        Fallback solution of mapping from config module name to module name in model class.
-        Please check if it aligns with your base model.
-        Please implement the function in the model class if it is not.
-        You can reference this function in llama.py.
-        """
-        params_mapping = {
-            "q_proj": "qkv_proj",
-            "k_proj": "qkv_proj",
-            "v_proj": "qkv_proj",
-            "gate_proj": "gate_up_proj",
-            "up_proj": "gate_up_proj",
-        }
-        return {params_mapping.get(name, name) for name in hf_module_names}
-
-
 def get_hidden_dim(
     module_name: str, config: AutoConfig, base_model: torch.nn.Module
 ) -> Tuple[int]:
@@ -92,14 +64,20 @@ def get_hidden_dim(
         Please implement the function in the model class if it is not.
         You can reference this function in llama.py.
         """
-        if module_name in ["q_proj", "o_proj", "qkv_proj"]:
-            return config.hidden_size, config.hidden_size
-        elif module_name in ["kv_proj"]:
-            return config.hidden_size, config.hidden_size // (
-                config.num_attention_heads // config.num_key_value_heads
+        head_dim = getattr(
+            config, "head_dim", config.hidden_size // config.num_attention_heads
+        )
+        if module_name == "qkv_proj":
+            return config.hidden_size, head_dim * (
+                config.num_attention_heads + config.num_key_value_heads * 2
+            )
+        elif module_name == "o_proj":
+            return (
+                head_dim * config.num_attention_heads,
+                config.hidden_size,
             )
         elif module_name == "gate_up_proj":
-            return config.hidden_size, config.intermediate_size
+            return config.hidden_size, config.intermediate_size * 2
         elif module_name == "down_proj":
             return config.intermediate_size, config.hidden_size
         else:
@@ -108,26 +86,22 @@
 
 def get_normalized_lora_weight_names(
     target_modules: Iterable[str],
-) -> Tuple[Set[str], Set[str]]:
+) -> set[str]:
     """
     Mapping a list of target module name to names of the normalized LoRA weights.
-    Returned tuple contains (name for Lora A, name for Lora B)
     """
     params_mapping = {
-        "q_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "k_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "v_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "gate_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "up_proj": (["gate_up_proj"], ["gate_up_proj"]),
-        "qkv_proj": (["qkv_proj"], ["q_proj", "kv_proj"]),
-        "gate_up_proj": (["gate_up_proj"], ["gate_up_proj"]),
+        "q_proj": "qkv_proj",
+        "k_proj": "qkv_proj",
+        "v_proj": "qkv_proj",
+        "gate_proj": "gate_up_proj",
+        "up_proj": "gate_up_proj",
     }
 
-    result = (set(), set())
+    result = set()
     for name in target_modules:
-        lora_a, lora_b = params_mapping.get(name, ([name], [name]))
-        result[0].update(lora_a)
-        result[1].update(lora_b)
+        weight_name = params_mapping.get(name, name)
+        result.add(weight_name)
     return result
 
 
@@ -137,23 +111,21 @@ def get_stacked_multiply(module_name: str) -> int:
     """
     stacked_rank = {
         "qkv_proj": 3,
-        "kv_proj": 2,
         "gate_up_proj": 2,
     }
     return stacked_rank[module_name] if module_name in stacked_rank else 1
 
 
 def get_weight_name(
-    target_name: str, lora_weight_names: Tuple[Set[str]], lora_type: LoRAType
+    target_name: str, lora_weight_names: Set[str]
 ) -> Optional[str]:
     """
-
-
+    Get the weight name in lora_weight_names that can match target_name.
+
     If there is a weight name in lora_weight_names that can match target_name, return this name
     Else raise ValueError.
     """
-    idx = 0 if lora_type == LoRAType.LORA_A else 1
-    for weight_name in lora_weight_names[idx]:
+    for weight_name in lora_weight_names:
         if weight_name in target_name:
             return weight_name
     raise ValueError(
@@ -161,9 +133,4 @@
     )
 
 
-# TODO: [PR #4274] For future use to simplify the mapping between HF module names and customized module names.
-VOCAB_PARALLELISM_EMBEDDING_NAMES = ["embeddings"]
-COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_proj", "up_proj"]
-MERGED_COLUMN_PARALLELISM_LINEAR_LORA_NAMES = ["gate_up_proj"]
-QKV_PARALLELISM_LINEAR_LORA_NAMES = ["qkv_proj"]
 ROW_PARALLELISM_LINEAR_LORA_NAMES = ["o_proj", "down_proj"]
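Note: a small standalone sketch of the simplified weight-name normalization and the qkv_proj hidden-dim arithmetic after this change. The mapping dict is copied from the diff; the config numbers are made-up Llama-like values.

# Normalization now returns a single flat set instead of an (A-names, B-names) tuple.
params_mapping = {
    "q_proj": "qkv_proj",
    "k_proj": "qkv_proj",
    "v_proj": "qkv_proj",
    "gate_proj": "gate_up_proj",
    "up_proj": "gate_up_proj",
}
target_modules = ["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj"]
normalized = {params_mapping.get(name, name) for name in target_modules}
print(normalized)  # {'qkv_proj', 'o_proj', 'gate_up_proj'}

# get_hidden_dim("qkv_proj", ...) output dimension, with hypothetical config values:
hidden_size, num_attention_heads, num_key_value_heads = 4096, 32, 8
head_dim = hidden_size // num_attention_heads                         # fallback when config has no head_dim
qkv_out = head_dim * (num_attention_heads + num_key_value_heads * 2)
print(hidden_size, qkv_out)                                           # (4096, 6144) for qkv_proj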
|