sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/lora/layers.py
CHANGED
@@ -1,3 +1,5 @@
+from typing import List, Tuple
+
 import torch
 from torch import nn
 
@@ -21,14 +23,10 @@ class BaseLayerWithLoRA(nn.Module):
     def __init__(
         self,
         base_layer: nn.Module,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ):
         super().__init__()
         self.base_layer: nn.Module = base_layer
-        self.lora_rank: int = lora_rank
-        self.scaling: float = scaling
         self.set_lora: bool = False
         self.lora_backend: BaseLoRABackend = lora_backend
 
@@ -38,16 +36,28 @@ class BaseLayerWithLoRA(nn.Module):
     def set_lora_info(self, *args):
         pass
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        pass
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        pass
+
 
 class VocabParallelEmbeddingWithLoRA(BaseLayerWithLoRA):
+    """
+    Vocab parallel embedding layer with support for LoRA (Low-Rank Adaptation).
+
+    Note: The current version does not yet implement the LoRA functionality.
+    This class behaves exactly the same as the base VocabParallelEmbedding.
+    Future versions will integrate LoRA functionality to support efficient parameter fine-tuning.
+    """
+
     def __init__(
         self,
         base_layer: VocabParallelEmbedding,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
         self.weight = base_layer.weight
 
 
@@ -55,11 +65,9 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: ColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -71,7 +79,7 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -80,8 +88,8 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -101,16 +109,24 @@ class ColumnParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias if self.base_layer.skip_bias_add else None
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        B = B[start_idx:end_idx, :]
+        return B
+
 
 class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def __init__(
         self,
         base_layer: MergedColumnParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -120,6 +136,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
+            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
             self.B_buffer_gate_up = torch.cat(
                 (B_buffer[0], B_buffer[1]), dim=-2
@@ -128,7 +145,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
 
         lora_output = self.lora_backend.run_gate_up_lora(
             x,
@@ -138,20 +155,28 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        # Since the outputs for both gate and up are identical, we use a random one.
+        shard_size = self.base_layer.output_partition_sizes[0]
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        return B[:, start_idx:end_idx, :]
+
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
     def init__(
         self,
         base_layer: QKVParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(
         self,
@@ -193,7 +218,7 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         if self.lora_backend.fuse_stacked_lora_b:
             backend_kwargs["output_offset"] = self.output_offset
             backend_kwargs["max_qkv_out_dim"] = self.max_qkv_out_dim
@@ -206,20 +231,39 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        return A
+
+    def slice_lora_b_weights(
+        self, B: List[torch.Tensor], tp_rank: int
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        B_q, B_kv = B
+        base_layer = self.base_layer
+        q_proj_shard_size = base_layer.q_proj_shard_size
+        kv_proj_shard_size = base_layer.kv_proj_shard_size
+        num_kv_head_replicas = base_layer.num_kv_head_replicas
+
+        q_start_idx = q_proj_shard_size * tp_rank
+        q_end_idx = q_start_idx + q_proj_shard_size
+
+        kv_shard_id = tp_rank // num_kv_head_replicas
+        kv_start_idx = kv_proj_shard_size * kv_shard_id
+        kv_end_idx = kv_start_idx + kv_proj_shard_size
+
+        return B_q[q_start_idx:q_end_idx, :], B_kv[:, kv_start_idx:kv_end_idx, :]
+
 
 class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
     def __init__(
         self,
         base_layer: RowParallelLinear,
-        lora_rank: int,
-        scaling: float,
         lora_backend: BaseLoRABackend,
     ) -> None:
-        super().__init__(base_layer,
+        super().__init__(base_layer, lora_backend)
 
     def set_lora_info(self, A_buffer: torch.Tensor, B_buffer: torch.Tensor):
         self.set_lora = True
@@ -227,7 +271,7 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         self.B_buffer = B_buffer
 
     def apply_lora(self, base_output: torch.Tensor, x: torch.Tensor) -> torch.Tensor:
-        backend_kwargs = {"base_output": base_output
+        backend_kwargs = {"base_output": base_output}
         lora_a_output = self.lora_backend.run_lora_a_sgemm(x, self.A_buffer)
         lora_output = self.lora_backend.run_lora_b_sgemm(
             lora_a_output,
@@ -236,8 +280,8 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         )
         return (
             lora_output
-            if self.lora_backend.
-            else base_output + lora_output
+            if self.lora_backend.fuse_output_add
+            else base_output + lora_output
         )
 
     def forward(self, input_: torch.Tensor):
@@ -274,9 +318,19 @@ class RowParallelLinearWithLoRA(BaseLayerWithLoRA):
         output_bias = self.base_layer.bias
         return output, output_bias
 
+    def slice_lora_a_weights(self, A: torch.Tensor, tp_rank: int):
+        shard_size = self.base_layer.input_size_per_partition
+        start_idx = tp_rank * shard_size
+        end_idx = (tp_rank + 1) * shard_size
+        A = A[:, start_idx:end_idx].contiguous()
+        return A
+
+    def slice_lora_b_weights(self, B: torch.Tensor, tp_rank: int):
+        return B
+
 
 def get_lora_layer(
-    layer: nn.Module,
+    layer: nn.Module, lora_backend: BaseLoRABackend
 ) -> BaseLayerWithLoRA:
     supported_layer_types = {
         # the order matters
@@ -288,6 +342,6 @@ def get_lora_layer(
     }
     for src_layer_type, lora_layer_type in supported_layer_types.items():
         if isinstance(layer, src_layer_type):  # pylint: disable=unidiomatic-typecheck
-            ret = lora_layer_type(layer,
+            ret = lora_layer_type(layer, lora_backend)
             return ret
     raise Exception(f"No corresponding LoRA layer supported for {type(layer)}.")
sglang/srt/lora/lora.py
CHANGED
@@ -39,16 +39,9 @@ class LoRALayer(nn.Module):
         super().__init__()
         self.config: LoRAConfig = config
         self.base_hf_config: AutoConfig = base_hf_config
-        self.weights: Dict[str, torch.Tensor] = {}
-        self.weight_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weight_gpu[name] = weight.to(torch.float16).to("cuda")
 
-
-
-            self.weight_gpu[name] = None
+        # lora weights in cpu. The weights are loaded from checkpoint.
+        self.weights: Dict[str, torch.Tensor] = {}
 
 
 class LoRAAdapter(nn.Module):
@@ -77,19 +70,6 @@ class LoRAAdapter(nn.Module):
         )
 
         self.weights: Dict[str, torch.Tensor] = {}
-        self.weights_gpu: Dict[str, torch.Tensor] = {}
-
-    def load_to_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = weight.to(torch.float16).to("cuda")
-        for layer in self.layers:
-            layer.load_to_gpu()
-
-    def offload_from_gpu(self):
-        for name, weight in self.weights.items():
-            self.weights_gpu[name] = None
-        for layer in self.layers:
-            layer.offload_from_gpu()
 
     # initialize the LoRA weights to cpu
     def initialize_weights(self):
sglang/srt/lora/lora_manager.py
CHANGED
@@ -23,7 +23,7 @@ import torch
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.hf_transformers_utils import AutoConfig
 from sglang.srt.lora.backend import BaseLoRABackend, get_backend_from_name
-from sglang.srt.lora.layers import get_lora_layer
+from sglang.srt.lora.layers import BaseLayerWithLoRA, get_lora_layer
 from sglang.srt.lora.lora import LoRAAdapter
 from sglang.srt.lora.lora_config import LoRAConfig
 from sglang.srt.lora.mem_pool import LoRAMemoryPool
@@ -51,6 +51,8 @@ class LoRAManager:
         load_config: LoadConfig,
         dtype: torch.dtype,
         lora_backend: str = "triton",
+        tp_size: int = 1,
+        tp_rank: int = 0,
     ):
         self.base_model: torch.nn.Module = base_model
         self.lora_paths: Dict[str, str] = lora_paths
@@ -58,6 +60,9 @@
         self.max_loras_per_batch: int = max_loras_per_batch
         self.load_config: LoadConfig = load_config
         self.dtype: torch.dtype = dtype
+        self.device: torch.device = next(self.base_model.parameters()).device
+        self.tp_size: int = tp_size
+        self.tp_rank: int = tp_rank
 
         # LoRA backend for running sgemm kernels
         logger.info(f"Using {lora_backend} as backend of LoRA kernels.")
@@ -98,11 +103,14 @@
             self.loras[name] = lora_adapter
 
         # misc lora configs
-        # FIXME remove the restrictions after implementing unified paging
         self.max_lora_dim: int = max([x.hf_config["r"] for x in self.configs.values()])
-
-
-
+
+        if self.lora_backend == "flashinfer":
+            # FIXME remove the restrictions after supporting multi-rank for flashinfer backend
+            max_lora_dim = max([x.hf_config["r"] for x in self.configs.values()])
+            scaling = list(self.loras.values())[0].scaling
+            assert all(x.hf_config["r"] == max_lora_dim for x in self.configs.values())
+            assert all(x.scaling == scaling for x in self.loras.values())
 
         # Convert original model layers to layers with LoRA
         self.convert_to_lora_layers()
@@ -110,7 +118,13 @@
     def init_lora_memory_pool(self):
         # Initialize memory pool
         self.memory_pool = LoRAMemoryPool(
-            self.base_hf_config,
+            self.base_hf_config,
+            self.max_loras_per_batch,
+            self.max_lora_dim,
+            self.dtype,
+            self.tp_size,
+            self.tp_rank,
+            self.lora_modules,
         )
 
         # Initialize target lora modules in memory pool
@@ -131,14 +145,24 @@
         seg_lens = (
             forward_batch.extend_seq_lens
             if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=
+            else torch.ones(bs, device=self.device)
         )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=
+        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
         seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
         max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=
+        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+        lora_ranks = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+        )
+        scalings = torch.empty(
+            (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+        )
         for i, lora_path in enumerate(forward_batch.lora_paths):
             weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+            lora = self.loras[lora_path]
+            lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+            scalings[weight_indices[i]] = lora.scaling
 
         batch_info = LoRABatchInfo(
             bs=bs,
@@ -146,31 +170,41 @@
             seg_indptr=seg_indptr,
             max_len=max_len,
             weight_indices=weight_indices,
+            lora_ranks=lora_ranks,
+            scalings=scalings,
         )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
-        for
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        for layer_id, modules in self.lora_modules.items():
+            for module_name, module in modules:
+                if "qkv_proj" in module_name:
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            "qkv_proj", layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            "q_proj", layer_id, LoRAType.LORA_B
+                        ),
+                        self.memory_pool.get_tensor(
+                            "kv_proj", layer_id, LoRAType.LORA_B
+                        ),
+                    )
+                else:
+                    weight_name = get_weight_name(
+                        module_name, self.lora_weight_names, LoRAType.LORA_A
+                    )
+                    module.set_lora_info(
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_A
+                        ),
+                        self.memory_pool.get_tensor(
+                            weight_name, layer_id, LoRAType.LORA_B
+                        ),
+                    )
 
     def set_lora_module(self, module_name, module):
-        lora_module = get_lora_layer(
-            module, self.max_lora_dim, self.scaling, self.lora_backend
-        )
+        lora_module = get_lora_layer(module, self.lora_backend)
         replace_submodule(self.base_model, module_name, lora_module)
         return lora_module
 
@@ -182,10 +216,13 @@
         )
 
         # Monkey patch to use the LoRA version layers
-        self.lora_modules: List[Tuple[str,
+        self.lora_modules: Dict[int, List[Tuple[str, BaseLayerWithLoRA]]] = {
+            i: [] for i in range(self.base_hf_config.num_hidden_layers)
+        }
         for module_name, module in self.base_model.named_modules():
            # The module should be converted if it is included in target_names
            if module_name.split(".")[-1] in customized_target_names:
-
+                layer_id = get_layer_id(module_name)
+                self.lora_modules[layer_id].append(
                     (module_name, self.set_lora_module(module_name, module))
                 )