sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -14,6 +14,9 @@ except ImportError:
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
 from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    fp8_max,
+    is_fp8_fnuz,
     per_token_group_quant_fp8,
     scaled_fp8_quant,
     sglang_per_token_quant_fp8,
@@ -30,8 +33,11 @@ from sglang.srt.utils import (
 
 _is_hip = is_hip()
 _is_cuda = is_cuda()
+_is_fp8_fnuz = is_fp8_fnuz()
 
-
+use_aiter_moe = get_bool_env_var("SGLANG_AITER_MOE")
+
+if _is_hip and use_aiter_moe:
     from aiter import gemm_a8w8_blockscale
 
 if _is_cuda:
@@ -43,19 +49,23 @@ use_vllm_cutlass_w8a8_fp8_kernel = get_bool_env_var("USE_VLLM_CUTLASS_W8A8_FP8_K
 # from pytorch 2.5. Allocating a dummy tensor to pass as input_scale
 TORCH_DEVICE_IDENTITY = None
 
-
-
-
-
-
-
-
-
-# The condition is
-#
-
-
-)
+
+def use_rowwise_torch_scaled_mm():
+    _TORCH_VERSION = torch.__version__.split("+")[0]
+    try:
+        _TORCH_VERSION_TUPLE = tuple(map(int, _TORCH_VERSION.split(".")[:3]))
+    except ValueError:
+        _TORCH_VERSION_TUPLE = (0, 0, 0)
+    if _is_hip:
+        # The condition to determine if it is on a platform that supports
+        # torch._scaled_mm rowwise feature.
+        # The condition is determined once as the operations
+        # are time consuming.
+        return get_device_capability() >= (9, 4) and _TORCH_VERSION_TUPLE >= (2, 7, 0)
+    return False
+
+
+USE_ROWWISE_TORCH_SCALED_MM = use_rowwise_torch_scaled_mm()
 
 
 def cutlass_fp8_supported():
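The new use_rowwise_torch_scaled_mm() helper gates the rowwise torch._scaled_mm path on both the device capability (>= (9, 4) on ROCm) and a minimum PyTorch version. As a standalone illustration (plain PyTorch, not the sglang API), the version string can be parsed the same way the helper does:

    import torch

    def torch_version_tuple():
        # "2.7.0+rocm6.3" -> (2, 7, 0); unparsable local builds fall back to (0, 0, 0)
        version = torch.__version__.split("+")[0]
        try:
            return tuple(map(int, version.split(".")[:3]))
        except ValueError:
            return (0, 0, 0)

    supports_rowwise_scaled_mm = torch_version_tuple() >= (2, 7, 0)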
@@ -132,7 +142,7 @@ def apply_w8a8_block_fp8_linear(
         output = fp8_blockwise_scaled_mm(
             q_input, weight.T, x_scale, weight_scale.T, out_dtype=input.dtype
         )
-    elif _is_hip and
+    elif _is_hip and use_aiter_moe:
         q_input, x_scale = per_token_group_quant_fp8(
             input_2d, block_size[1], column_major_scales=False
         )
@@ -164,18 +174,21 @@ def apply_w8a8_block_fp8_linear(
 
 
 def input_to_float8(
-    x: torch.Tensor, dtype: torch.dtype =
+    x: torch.Tensor, dtype: torch.dtype = fp8_dtype
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """This function quantizes input values to float8 values with tensor-wise quantization."""
-    finfo = torch.finfo(dtype)
     min_val, max_val = x.aminmax()
     amax = torch.maximum(min_val.abs(), max_val.abs()).float().clamp(min=1e-12)
-
-    if
-    dtype =
-
-
-
+
+    if _is_fp8_fnuz:
+        dtype = fp8_dtype
+        fp_max = fp8_max
+    else:
+        finfo = torch.finfo(dtype)
+        fp_max = finfo.max
+
+    scale = fp_max / amax
+    x_scl_sat = (x.float() * scale).clamp(min=-fp_max, max=fp_max)
     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
 
 
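input_to_float8 now defaults to the platform FP8 dtype and clamps against fp8_max on FNUZ (MI300-class) devices. A minimal standalone sketch of the same tensor-wise quantization on a non-FNUZ platform (plain PyTorch, not the sglang helper):

    import torch

    x = torch.randn(128, 256, dtype=torch.bfloat16)
    finfo = torch.finfo(torch.float8_e4m3fn)

    amax = x.abs().amax().float().clamp(min=1e-12)
    scale = finfo.max / amax
    x_q = (x.float() * scale).clamp(min=-finfo.max, max=finfo.max).to(torch.float8_e4m3fn)
    dequant_scale = scale.reciprocal()      # multiply x_q by this to roughly recover x
    x_approx = x_q.float() * dequant_scale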
@@ -222,6 +235,41 @@ def block_quant_to_tensor_quant(
     return x_q_tensor, scale
 
 
+def block_quant_dequant(
+    x_q_block: torch.Tensor,
+    x_s: torch.Tensor,
+    block_size: List[int],
+    dtype: torch.dtype,
+) -> torch.Tensor:
+    """This function converts block-wise quantization to unquantized.
+    The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+    and the block size.
+    The output is an unquantized tensor with dtype.
+    """
+    block_n, block_k = block_size[0], block_size[1]
+    n, k = x_q_block.shape
+    n_tiles = (n + block_n - 1) // block_n
+    k_tiles = (k + block_k - 1) // block_k
+    assert n_tiles == x_s.shape[0]
+    assert k_tiles == x_s.shape[1]
+
+    x_dq_block = torch.empty_like(x_q_block, dtype=dtype)
+
+    for j in range(n_tiles):
+        for i in range(k_tiles):
+            x_q_block_tile = x_q_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile = x_dq_block[
+                j * block_n : min((j + 1) * block_n, n),
+                i * block_k : min((i + 1) * block_k, k),
+            ]
+            x_dq_block_tile[:, :] = x_q_block_tile.to(torch.float32) * x_s[j][i]
+
+    return x_dq_block
+
+
 def channel_quant_to_tensor_quant(
     x_q_channel: torch.Tensor,
     x_s: torch.Tensor,
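The new block_quant_dequant helper multiplies each (block_n, block_k) tile by its scale to recover a dense tensor. A standalone sketch of that tile-wise dequantization with made-up shapes (plain PyTorch, independent of the sglang function):

    import torch

    block_n, block_k = 128, 128
    n, k = 256, 320                      # deliberately not a multiple of block_k
    x_q = torch.randn(n, k)              # stand-in for a block-quantized weight
    x_s = torch.rand((n + block_n - 1) // block_n,
                     (k + block_k - 1) // block_k) + 0.5   # one scale per tile

    x_dq = torch.empty_like(x_q)
    for j in range(x_s.shape[0]):
        for i in range(x_s.shape[1]):
            rows = slice(j * block_n, min((j + 1) * block_n, n))
            cols = slice(i * block_k, min((i + 1) * block_k, k))
            x_dq[rows, cols] = x_q[rows, cols] * x_s[j, i]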
sglang/srt/layers/quantization/kv_cache.py
CHANGED
@@ -8,10 +8,8 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
+from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import is_hip
-
-_is_hip = is_hip()
 
 logger = logging.getLogger(__name__)
 
@@ -44,11 +42,6 @@ class BaseKVCacheMethod(QuantizeMethodBase):
             torch.tensor(-1.0, dtype=torch.float32), requires_grad=False
         )
 
-    @classmethod
-    def is_fp8_fnuz(cls) -> bool:
-        # only device 0 is checked, this assumes MI300 platforms are homogeneous
-        return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
     def apply(self, layer: torch.nn.Module) -> torch.Tensor:
         raise RuntimeError(f"{self.__class__.__name__}.apply should not be called.")
 
@@ -57,7 +50,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 # We prefer to use separate k_scale and v_scale if present
                 k_scale = layer.k_scale.to("cpu").tolist()
                 v_scale = layer.v_scale.to("cpu").tolist()
-                if
+                if is_fp8_fnuz():
                     k_scale *= 2
                     v_scale *= 2
             elif layer.k_scale < 0.0 and layer.v_scale < 0.0:
@@ -73,7 +66,7 @@ class BaseKVCacheMethod(QuantizeMethodBase):
                 scale_to_duplicate = max(layer.k_scale, layer.v_scale)
                 k_scale = scale_to_duplicate.to("cpu").tolist()
                 v_scale = scale_to_duplicate.to("cpu").tolist()
-                if
+                if is_fp8_fnuz():
                     k_scale *= 2
                     v_scale *= 2
 
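The scale doubling above applies because float8_e4m3fnuz uses an exponent bias of 8 instead of 7, so the same 8-bit pattern decodes to half the float8_e4m3fn value; scales computed against e4m3fn checkpoints are therefore doubled on FNUZ platforms. A quick standalone check of that factor (plain PyTorch, assuming a build that ships both FP8 dtypes):

    import torch

    bits = torch.tensor([0x38], dtype=torch.uint8)           # an arbitrary FP8 bit pattern
    as_fn = bits.view(torch.float8_e4m3fn).float().item()    # decodes to 1.0
    as_fnuz = bits.view(torch.float8_e4m3fnuz).float().item()  # decodes to 0.5
    print(as_fn, as_fnuz, as_fn / as_fnuz)                   # same bits, 2x apart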
sglang/srt/layers/quantization/utils.py
CHANGED
@@ -14,11 +14,6 @@ if not _is_cuda:
     from vllm._custom_ops import scaled_fp8_quant
 
 
-def is_fp8_fnuz() -> bool:
-    # only device 0 is checked, this assumes MI300 platforms are homogeneous
-    return "gfx94" in torch.cuda.get_device_properties(0).gcnArchName
-
-
 def is_layer_skipped(
     prefix: str,
     ignored_layers: List[str],
sglang/srt/layers/quantization/w8a8_fp8.py
CHANGED
@@ -9,16 +9,20 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import
+from sglang.srt.layers.quantization.fp8_kernel import (
+    fp8_dtype,
+    is_fp8_fnuz,
+    per_token_group_quant_fp8,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
-from sglang.srt.utils import
+from sglang.srt.utils import set_weight_attrs
 
-
+_is_fp8_fnuz = is_fp8_fnuz()
 
 
 class W8A8Fp8Config(QuantizationConfig):
@@ -97,7 +101,7 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
         if self.quantization_config.is_checkpoint_fp8_serialized:
             weight_scale = layer.weight_scale.detach()
             # If checkpoint offline quantized with w8a8_fp8, load the weight and weight_scale directly.
-            if
+            if _is_fp8_fnuz:
                 weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
                     weight=weight, weight_scale=weight_scale
                 )
@@ -113,14 +117,9 @@ class W8A8Fp8LinearMethod(LinearMethodBase):
                     layer.weight, layer.weight.shape[-1]
                 )
                 weight_scale = weight_scale.t().contiguous()
-                if _is_hip:
-                    weight, weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=weight, weight_scale=weight_scale
-                    )
             else:
                 # if cutlass not supported, we fall back to use torch._scaled_mm
                 # which requires per tensor quantization on weight
-                fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
                 qweight, weight_scale = input_to_float8(layer.weight, dtype=fp8_dtype)
 
         # Update the layer with the new values.
@@ -227,7 +226,6 @@ class W8A8FP8MoEMethod:
     ):
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
 
-        fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
         # WEIGHTS
         w13_weight = torch.nn.Parameter(
             torch.empty(
sglang/srt/layers/utils.py
ADDED
@@ -0,0 +1,35 @@
+import logging
+import re
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+
+def get_layer_id(weight_name):
+    # example weight name: model.layers.10.self_attn.qkv_proj.weight
+    match = re.search(r"layers\.(\d+)\.", weight_name)
+    if match:
+        return int(match.group(1))
+    return None
+
+
+class PPMissingLayer(torch.nn.Identity):
+    # Adapted from
+    # https://github.com/vllm-project/vllm/blob/18ed3132d2bfe1df9a74729457b69243955221e8/vllm/model_executor/models/utils.py#L468C1-L486C1
+    """
+    A placeholder layer for missing layers in a pipeline parallel model.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__()
+        self.return_tuple = kwargs.get("return_tuple", False)
+
+    def forward(self, *args, **kwargs):
+        """
+        Return the first arg from args or the first value from kwargs.
+
+        Wraps the input in a tuple if `self.return_tuple` is True.
+        """
+        input = args[0] if args else next(iter(kwargs.values()))
+        return (input,) if self.return_tuple else input
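A short usage sketch for the two helpers added in this new module, assuming it is importable as sglang.srt.layers.utils: get_layer_id extracts the layer index from a checkpoint parameter name, and PPMissingLayer is an identity placeholder for layers hosted on another pipeline-parallel rank:

    import torch
    from sglang.srt.layers.utils import PPMissingLayer, get_layer_id

    assert get_layer_id("model.layers.10.self_attn.qkv_proj.weight") == 10
    assert get_layer_id("lm_head.weight") is None

    hidden = torch.randn(2, 16)
    assert PPMissingLayer()(hidden) is hidden        # plain passthrough

    out = PPMissingLayer(return_tuple=True)(hidden)  # wrapped for callers that expect tuples
    assert isinstance(out, tuple) and out[0] is hidden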
sglang/srt/lora/layers.py
CHANGED
@@ -136,11 +136,19 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         self.set_lora = True
         self.A_buffer_gate_up = A_buffer
         if self.lora_backend.fuse_stacked_lora_b:
-            # TODO: avoid using contiguous() in GPU.
             # B_buffer_gate_up: (num_lora, 2 * output_dim, r)
-            self.B_buffer_gate_up
-
-
+            if not hasattr(self, "B_buffer_gate_up") or self.B_buffer_gate_up is None:
+                self.B_buffer_gate_up = torch.empty(
+                    (
+                        B_buffer[0].shape[0],
+                        2 * B_buffer[0].shape[1],
+                        B_buffer[0].shape[2],
+                    ),
+                    dtype=B_buffer[0].dtype,
+                    device=B_buffer[0].device,
+                )
+            self.B_buffer_gate_up[:, : B_buffer[0].shape[1], :].copy_(B_buffer[0])
+            self.B_buffer_gate_up[:, B_buffer[0].shape[1] :, :].copy_(B_buffer[1])
         else:
             self.B_buffer_gate_up = (B_buffer[0], B_buffer[1])
 
@@ -171,7 +179,7 @@ class MergedColumnParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
 
 
 class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
-    def
+    def __init__(
         self,
         base_layer: QKVParallelLinear,
         lora_backend: BaseLoRABackend,
@@ -194,12 +202,30 @@ class QKVParallelLinearWithLoRA(ColumnParallelLinearWithLoRA):
         output_dim_q, output_dim_kv = B_buffer_q.shape[-2], B_buffer_kv.shape[-2]
 
         # B_buffer_qkv: (num_lora, output_dim_q + 2 * output_dim_kv, r)
-        self.B_buffer_qkv
-
-
+        if not hasattr(self, "B_buffer_qkv") or self.B_buffer_qkv is None:
+            self.B_buffer_qkv = torch.empty(
+                (
+                    B_buffer_q[0].shape[0],
+                    output_dim_q + 2 * output_dim_kv,
+                    B_buffer_q[0].shape[2],
+                ),
+                dtype=B_buffer_q[0].dtype,
+                device=B_buffer_q[0].device,
+            )
+        self.B_buffer_qkv[:, :output_dim_q, :].copy_(B_buffer_q[0])
+        self.B_buffer_qkv[:, output_dim_q : output_dim_q + output_dim_kv, :].copy_(
+            B_buffer_kv[0]
+        )
+        self.B_buffer_qkv[:, output_dim_q + output_dim_kv :, :].copy_(
+            B_buffer_kv[1]
+        )
 
         # Offsets of q/k/v in output dimension
-        self.output_offset
+        if not hasattr(self, "output_offset") or self.output_offset is None:
+            self.output_offset = torch.empty(
+                4, dtype=torch.int32, device=B_buffer_q.device
+            )
+        self.output_offset[:4] = torch.tensor(
             [
                 0,
                 output_dim_q,
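Both LoRA modules above move from rebuilding the stacked B buffer on every call to allocating it once and refreshing it in place with copy_, which keeps buffer addresses stable across batches (a prerequisite for CUDA graph replay). A minimal standalone sketch of that pattern with made-up shapes:

    import torch

    num_lora, output_dim, r = 4, 64, 8
    b_gate = torch.randn(num_lora, output_dim, r)
    b_up = torch.randn(num_lora, output_dim, r)

    # Allocate the fused (num_lora, 2 * output_dim, r) buffer once ...
    fused = torch.empty(num_lora, 2 * output_dim, r, dtype=b_gate.dtype, device=b_gate.device)

    # ... then refresh it in place each batch instead of torch.cat(...).contiguous()
    fused[:, :output_dim, :].copy_(b_gate)
    fused[:, output_dim:, :].copy_(b_up)
    assert torch.equal(fused, torch.cat([b_gate, b_up], dim=1))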
sglang/srt/lora/lora_manager.py
CHANGED
@@ -72,6 +72,23 @@ class LoRAManager:
         self.init_loras()
         self.init_lora_memory_pool()
 
+    def init_cuda_graph_batch_info(self, max_bs_in_cuda_graph: int):
+        self.max_bs_in_cuda_graph = max_bs_in_cuda_graph
+        with torch.device("cuda"):
+            self.cuda_graph_batch_info = LoRABatchInfo(
+                bs=self.max_bs_in_cuda_graph,
+                seg_lens=torch.zeros(self.max_bs_in_cuda_graph, dtype=torch.int32),
+                seg_indptr=torch.zeros(
+                    self.max_bs_in_cuda_graph + 1, dtype=torch.int32
+                ),
+                max_len=0,
+                weight_indices=torch.zeros(
+                    self.max_bs_in_cuda_graph, dtype=torch.int32
+                ),
+                lora_ranks=torch.zeros(self.max_loras_per_batch, dtype=torch.int32),
+                scalings=torch.zeros(self.max_loras_per_batch, dtype=torch.float),
+            )
+
     def init_loras(self):
         # Config of each LoRA adapter
         self.configs: Dict[str, LoRAConfig] = {}
@@ -136,43 +153,72 @@ class LoRAManager:
         assert len(cur_uids) <= self.max_loras_per_batch
         self.memory_pool.prepare_lora_batch(cur_uids, self.loras)
 
-        #
-        if cur_uids == set([None]):
-            return
-
-        # set up batch info shared by all lora moruldes
+        # set up batch info shared by all lora modules
         bs = forward_batch.batch_size
-        seg_lens = (
-            forward_batch.extend_seq_lens
-            if forward_batch.forward_mode.is_extend()
-            else torch.ones(bs, device=self.device)
-        )
-        seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
-        seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
-        max_len = int(torch.max(seg_lens))
-        weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
 
-
-            (self
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        if (
+            hasattr(self, "max_bs_in_cuda_graph")
+            and bs <= self.max_bs_in_cuda_graph
+            and forward_batch.forward_mode.is_cuda_graph()
+        ):
+            # Do in-place updates when CUDA graph is enabled and the batch forward mode
+            # could use CUDA graph.
+            self.cuda_graph_batch_info.bs = bs
+            self.cuda_graph_batch_info.seg_lens[:bs].fill_(1)
+            torch.cumsum(
+                self.cuda_graph_batch_info.seg_lens[:bs],
+                dim=0,
+                out=self.cuda_graph_batch_info.seg_indptr[1 : bs + 1],
+            )
+            self.cuda_graph_batch_info.max_len = int(
+                torch.max(self.cuda_graph_batch_info.seg_lens[:bs])
+            )
+
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                self.cuda_graph_batch_info.weight_indices[i] = (
+                    self.memory_pool.get_buffer_id(lora_path)
+                )
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    self.cuda_graph_batch_info.lora_ranks[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.config.hf_config["r"]
+                    self.cuda_graph_batch_info.scalings[
+                        self.cuda_graph_batch_info.weight_indices[i]
+                    ] = lora.scaling
+            batch_info = self.cuda_graph_batch_info
+        else:
+            seg_lens = (
+                forward_batch.extend_seq_lens
+                if forward_batch.forward_mode.is_extend()
+                else torch.ones(bs, device=self.device)
+            )
+            seg_indptr = torch.zeros((bs + 1,), dtype=torch.int32, device=self.device)
+            seg_indptr[1:] = torch.cumsum(seg_lens, dim=0)
+            max_len = int(torch.max(seg_lens))
+            weight_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
+
+            lora_ranks = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.int64, device="cuda"
+            )
+            scalings = torch.zeros(
+                (self.max_loras_per_batch,), dtype=torch.float, device="cuda"
+            )
+            for i, lora_path in enumerate(forward_batch.lora_paths):
+                weight_indices[i] = self.memory_pool.get_buffer_id(lora_path)
+                if lora_path is not None:
+                    lora = self.loras[lora_path]
+                    lora_ranks[weight_indices[i]] = lora.config.hf_config["r"]
+                    scalings[weight_indices[i]] = lora.scaling
+            batch_info = LoRABatchInfo(
+                bs=bs,
+                seg_lens=seg_lens,
+                seg_indptr=seg_indptr,
+                max_len=max_len,
+                weight_indices=weight_indices,
+                lora_ranks=lora_ranks,
+                scalings=scalings,
+            )
         self.lora_backend.set_batch_info(batch_info)
 
         # call set_lora_info for each lora modules
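Under the CUDA-graph path above, prepare_lora_batch rewrites preallocated tensors in place instead of allocating new ones; notably torch.cumsum(..., out=...) fills the segment index pointer without creating a fresh tensor, so addresses captured in the graph stay valid. A standalone sketch of that in-place update, assuming one token per request during decode:

    import torch

    max_bs = 8
    seg_lens = torch.zeros(max_bs, dtype=torch.int32)        # preallocated once
    seg_indptr = torch.zeros(max_bs + 1, dtype=torch.int32)  # preallocated once

    bs = 3                                   # current batch size (<= max_bs)
    seg_lens[:bs].fill_(1)                   # decode: one new token per request
    torch.cumsum(seg_lens[:bs], dim=0, out=seg_indptr[1 : bs + 1])
    print(seg_indptr[: bs + 1].tolist())     # [0, 1, 2, 3]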