sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/quant_utils.py
ADDED
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            #  values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
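The helpers above mirror vLLM's quantization utilities: pack_cols folds 32 // num_bits quantized values into each int32 word column-wise, and unpack_cols reverses the packing. A minimal round-trip sketch (the shapes and num_bits below are illustrative, not taken from the package):

    import torch

    from sglang.srt.layers.quantization.quant_utils import pack_cols, unpack_cols

    # 4-bit values: 8 of them fit into each 32-bit word, so size_n shrinks by 8x.
    num_bits, size_k, size_n = 4, 16, 32
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)         # shape (16, 4)
    restored = unpack_cols(packed, num_bits, size_k, size_n)  # shape (16, 32)
    assert torch.equal(restored, q_w)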
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["weight"])
+            return
+
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
@@ -112,6 +128,16 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                x,
+                layer.weight,
+                layer.weight_scale,
+                bias,
+                x.dtype,
+                True,  # is_vnni
+            )
+
         x_q, x_scale = per_token_quant_int8(x)
 
         return int8_scaled_mm(
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
@@ -252,6 +285,24 @@ class W8A8Int8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                True,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                layer.w13_weight_scale,  # w1_scale
+                layer.w2_weight_scale,  # w2_scale
+                None,  # block_size
+                layer.w13_input_scale,  # a1_scale
+                layer.w2_input_scale,  # a2_scale
+                True,  # is_vnni
+            )
+
         return fused_experts(
             x,
             layer.w13_weight,
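The w8a8_int8 changes add an Intel AMX CPU path: weights are repacked once in process_weights_after_loading, and apply() hands the raw activations to the fused sgl_kernel CPU ops instead of quantizing them first. On the existing CUDA path the activations are still quantized per token before int8_scaled_mm; a hypothetical reference for that step, assuming symmetric per-row scaling (the real per_token_quant_int8 kernel may differ in eps handling and scale dtype):

    import torch

    def per_token_quant_int8_reference(x: torch.Tensor):
        # Assumed semantics: one scale per row (token), symmetric int8 range.
        x_f32 = x.float()
        scale = x_f32.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.clamp(torch.round(x_f32 / scale), -128, 127).to(torch.int8)
        return x_q, scale

Judging by its name and arguments, int8_scaled_mm_with_quant fuses this quantization into the AMX matmul, which is why the CPU branch skips the explicit per_token_quant_int8 call.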
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -8,16 +8,29 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+)
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 
 if _is_cuda:
     from sgl_kernel import apply_rope_with_cos_sin_cache_inplace
+if _use_aiter:
+    from aiter.rotary_embedding import get_rope as aiter_get_rope
+
+if is_npu():
+    import torch_npu
 
 
 def _rotate_neox(x: torch.Tensor) -> torch.Tensor:
@@ -152,6 +165,36 @@ class RotaryEmbedding(CustomOp):
         key = torch.cat((key_rot, key_pass), dim=-1).reshape(key_shape)
         return query, key
 
+    def forward_npu(
+        self,
+        positions: torch.Tensor,
+        query: torch.Tensor,
+        key: torch.Tensor,
+        offsets: Optional[torch.Tensor] = None,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        """A PyTorch-npu implementation of forward()."""
+        import os
+
+        if get_bool_env_var("SGLANG_ENABLE_TORCH_COMPILE"):
+            return self.forward_native(positions, query, key, offsets)
+        else:
+            rotary_mode = "half"
+            if self.is_neox_style:
+                rotary_mode = "half"
+            else:
+                rotary_mode = "interleave"
+            mrope_section = [0, 0, 0]
+            query_out, key_out = torch_npu.npu_mrope(
+                positions,
+                query,
+                key,
+                self.cos_sin_cache,
+                self.head_size,
+                mrope_section=mrope_section,
+                rotary_mode=rotary_mode,
+            )
+            return query_out, key_out
+
     def forward_cpu(
         self,
         positions: torch.Tensor,
@@ -617,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
-        device: Optional[str] = "cuda",
+        device: Optional[str] = "cuda" if not _is_npu else "npu",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -636,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip:
+        if _is_hip or _is_npu:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
@@ -847,6 +890,43 @@ class Llama4VisionRotaryEmbedding(RotaryEmbedding):
         return query_out.type_as(query), key_out.type_as(key)
 
 
+class DynamicNTKAlphaRotaryEmbedding(RotaryEmbedding):
+    """RotaryEmbedding extended with Dynamic NTK scaling.
+
+    Credits to the Reddit users /u/bloc97 and /u/emozilla
+    """
+
+    def __init__(
+        self,
+        head_size: int,
+        rotary_dim: int,
+        max_position_embeddings: int,
+        base: int,
+        is_neox_style: bool,
+        scaling_alpha: float,
+        dtype: torch.dtype,
+    ) -> None:
+        self.scaling_alpha = scaling_alpha
+        super().__init__(
+            head_size, rotary_dim, max_position_embeddings, base, is_neox_style, dtype
+        )
+
+    def _compute_cos_sin_cache(self) -> torch.Tensor:
+        max_len = self.max_position_embeddings
+        base = self.base * self.scaling_alpha ** (
+            self.rotary_dim / (self.rotary_dim - 2)
+        )
+
+        inv_freq = self._compute_inv_freq(base)
+        t = torch.arange(max_len, dtype=torch.float)
+
+        freqs = torch.einsum("i,j -> ij", t, inv_freq)
+        cos = freqs.cos()
+        sin = freqs.sin()
+        cache = torch.cat((cos, sin), dim=-1)
+        return cache
+
+
 class MRotaryEmbedding(RotaryEmbedding):
     """Rotary Embedding with Multimodal Sections."""
 
@@ -1191,15 +1271,26 @@ def get_rope(
             )
         elif scaling_type == "dynamic":
             scaling_factor = rope_scaling["factor"]
-            rotary_emb = DynamicNTKScalingRotaryEmbedding(
-                head_size,
-                rotary_dim,
-                max_position,
-                base,
-                is_neox_style,
-                scaling_factor,
-                dtype,
-            )
+            if "alpha" in rope_scaling:
+                rotary_emb = DynamicNTKAlphaRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    rope_scaling["alpha"],
+                    dtype,
+                )
+            else:
+                rotary_emb = DynamicNTKScalingRotaryEmbedding(
+                    head_size,
+                    rotary_dim,
+                    max_position,
+                    base,
+                    is_neox_style,
+                    scaling_factor,
+                    dtype,
+                )
         elif scaling_type == "yarn":
             scaling_factor = rope_scaling["factor"]
             original_max_position = rope_scaling["original_max_position_embeddings"]
@@ -1388,7 +1479,8 @@ def get_rope_wrapper(
     device: Optional[str] = None,
 ):
     if device != "cpu":
-        return get_rope(
+        wrapper = aiter_get_rope if _use_aiter else get_rope
+        return wrapper(
             head_size,
             rotary_dim,
             max_position,
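The new DynamicNTKAlphaRotaryEmbedding rescales the RoPE base once from a fixed alpha, base' = base * alpha ** (rotary_dim / (rotary_dim - 2)), rather than deriving a factor from the current sequence length; get_rope selects it whenever a "dynamic" rope_scaling config carries an "alpha" key, and otherwise falls back to DynamicNTKScalingRotaryEmbedding as before. A small numeric sketch with illustrative values (not taken from any model config):

    import torch

    rotary_dim, base, alpha = 128, 10000.0, 8.0                   # illustrative
    new_base = base * alpha ** (rotary_dim / (rotary_dim - 2))    # ~82,700
    # Standard RoPE inverse frequencies, computed from the rescaled base.
    inv_freq = 1.0 / (
        new_base ** (torch.arange(0, rotary_dim, 2, dtype=torch.float32) / rotary_dim)
    )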
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,10 +21,13 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import set_weight_attrs
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
+
 
 class UnquantizedEmbeddingMethod(QuantizeMethodBase):
     """Unquantized method for embeddings."""
@@ -242,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.tp_size = 1
 
         self.num_embeddings = num_embeddings
-        self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
+
+        # Support the case where the vocab size is not divisible by the TP size.
+        if (
+            _is_cpu
+            and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+        ):
+            padding_size *= self.tp_size
+        self.padding_size = padding_size
+
         num_added_embeddings = num_embeddings - self.org_vocab_size
         self.use_presharded_weights = use_presharded_weights
         if use_presharded_weights:
@@ -549,6 +561,11 @@ class ParallelLMHead(VocabParallelEmbedding):
             use_presharded_weights=use_presharded_weights,
         )
         self.quant_config = quant_config
+
+        # We only support pack LMHead if it's not quantized. For LMHead with quant_config, the weight_name will be "qweight"
+        if self.quant_config is None and _is_cpu and _is_cpu_amx_available:
+            self.quant_method = PackWeightMethod(weight_names=["weight"])
+
         if bias:
             self.bias = Parameter(
                 torch.empty(self.num_embeddings_per_partition, dtype=params_dtype)
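The CPU branch in VocabParallelEmbedding enlarges the padding unit when the padded vocabulary would not split evenly across tensor-parallel ranks. A small arithmetic sketch with illustrative numbers; round_up below stands in for what pad_vocab_size is assumed to do (round up to the next multiple):

    def round_up(x: int, multiple: int) -> int:
        # Assumed behavior of pad_vocab_size: round up to the next multiple.
        return ((x + multiple - 1) // multiple) * multiple

    org_vocab_size, padding_size, tp_size = 32_000, 64, 6          # illustrative
    assert round_up(org_vocab_size, padding_size) % tp_size != 0   # 32000 % 6 == 2

    padding_size *= tp_size                                        # 64 -> 384
    padded = round_up(org_vocab_size, padding_size)                # 32256
    assert padded % tp_size == 0                                   # 32256 // 6 == 5376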
sglang/srt/lora/lora.py
CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
 
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        #
-        for
-
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
 
sglang/srt/lora/lora_manager.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
     get_normalized_lora_weight_names,
     get_weight_name,
 )
+from sglang.srt.managers.io_struct import LoRAUpdateResult
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.utils import replace_submodule
 
@@ -98,44 +99,96 @@ class LoRAManager:
             ],
         )
 
-    def
+    def create_lora_update_result(
+        self, success: bool, error_message: str = ""
+    ) -> LoRAUpdateResult:
+        return LoRAUpdateResult(
+            success=success,
+            error_message=error_message,
+            loaded_adapters={
+                name: config.path for name, config in self.configs.items()
+            },
+        )
+
+    def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
         """
         Load LoRA adapters from the specified paths.
-        TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
 
         Args:
             lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
                 If a LoRA adapter is already loaded, it will be skipped with a warning.
         """
 
+        results = []
         for lora_name, lora_path in lora_paths.items():
-
-
-
-
-
-
+            result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
+            results.append(result)
+
+        self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=all(result.success for result in results),
+            error_message="\n".join(
+                result.error_message for result in results if not result.success
+            ),
+        )
+
+    def load_lora_adapter(
+        self, lora_name: str, lora_path: str, update_state: bool = True
+    ) -> LoRAUpdateResult:
+        """
+        Load a single LoRA adapter from the specified path.
+
+        Args:
+            lora_name (str): The name of the LoRA adapter.
+            lora_path (str): The file path to the LoRA adapter.
+            update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
+        """
 
+        success = True
+        error_message = ""
+
+        if lora_name in self.loras:
+            success = False
+            error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
+
+        try:
             self.configs[lora_name] = LoRAConfig(lora_path)
+        except Exception as e:
+            success = False
+            error_message = (
+                f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
+            )
 
-
+        if update_state:
+            self.update_state_from_configs()
+
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
 
-    def
+    def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
         """
         Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
         delete the corresponding LoRA modules.
-
-        Args:
-            lora_names (Set[str]): A set of LoRA adapter names to unload.
         """
-
-
-
-
-
+
+        success = True
+        error_message = ""
+        if lora_name in self.loras:
+            del self.configs[lora_name]
+        else:
+            error_message = f"LoRA adapter {lora_name} is not loaded."
+            success = False
 
         self.update_state_from_configs()
 
+        return self.create_lora_update_result(
+            success=success,
+            error_message=error_message,
+        )
+
     def prepare_lora_batch(self, forward_batch: ForwardBatch):
         # load active loras into lora memory pool
         cur_uids = set(forward_batch.lora_paths)
@@ -372,8 +425,8 @@ class LoRAManager:
             lora_adapter.initialize_weights()
             self.loras[name] = lora_adapter
 
-        # Clean up unused LoRA adapters
-        for name in self.loras:
+        # Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
+        for name in list(self.loras):
             if name not in self.configs:
                 logger.info(f"Unloading LoRA adapter {name}")
                 del self.loras[name]
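The reworked LoRAManager reports per-adapter load/unload outcomes through LoRAUpdateResult (success, error_message, loaded_adapters). A hypothetical usage sketch; the adapter names and paths are made up, and construction of the manager is elided:

    # lora_manager is an already-constructed LoRAManager instance (hypothetical setup).
    result = lora_manager.load_lora_adapter("sql-lora", "/models/loras/sql-lora")
    if not result.success:
        print(result.error_message)

    batch = lora_manager.load_lora_adapters(
        {"chat-lora": "/models/loras/chat-lora", "math-lora": "/models/loras/math-lora"}
    )
    print(batch.loaded_adapters)  # {adapter name: path} for every adapter currently loaded

    lora_manager.unload_lora_adapter("math-lora")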
sglang/srt/managers/configure_logging.py
CHANGED
@@ -28,7 +28,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument("--url", type=str, default="http://localhost:30000")
     parser.add_argument("--log-requests", action="store_true")
-    parser.add_argument("--log-requests-level", type=int, default=
+    parser.add_argument("--log-requests-level", type=int, default=3)
     parser.add_argument(
         "--dump-requests-folder", type=str, default="/tmp/sglang_request_dump"
     )