sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,264 @@
|
|
1
|
+
import logging
|
2
|
+
from typing import Any, Dict, List, Optional
|
3
|
+
|
4
|
+
import torch
|
5
|
+
from torch.nn import Module
|
6
|
+
from torch.nn.parameter import Parameter
|
7
|
+
|
8
|
+
from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
|
9
|
+
from sglang.srt.layers.quantization.base_config import (
|
10
|
+
QuantizationConfig,
|
11
|
+
QuantizeMethodBase,
|
12
|
+
)
|
13
|
+
from sglang.srt.layers.quantization.fp8 import Fp8LinearMethod
|
14
|
+
from sglang.srt.layers.quantization.utils import is_layer_skipped
|
15
|
+
from sglang.srt.utils import set_weight_attrs
|
16
|
+
|
17
|
+
ACTIVATION_SCHEMES = ["static", "dynamic"]

logger = logging.getLogger(__name__)


class W4AFp8Config(QuantizationConfig):
    """Config class for MIXED_PRECISION W4AFp8.

    ``LinearBase`` layers are quantized with ``Fp8LinearMethod`` (or left
    unquantized when their prefix matches ``ignored_layers``), while
    ``FusedMoE`` layers use ``W4AFp8MoEMethod``.
    """

    def __init__(
        self,
        is_checkpoint_fp8_serialized: bool = True,
        is_checkpoint_w4afp8_serialized: bool = True,
        linear_activation_scheme: str = "dynamic",
        moe_activation_scheme: str = "static",
        ignored_layers: Optional[List[str]] = None,
        weight_block_size: Optional[List[int]] = None,
        group_size: int = 128,
    ) -> None:
        """
        Args:
            is_checkpoint_fp8_serialized: Whether the checkpoint stores fp8 data.
            is_checkpoint_w4afp8_serialized: Whether the checkpoint is w4afp8;
                a warning is logged when it is.
            linear_activation_scheme: One of ``ACTIVATION_SCHEMES``.
            moe_activation_scheme: One of ``ACTIVATION_SCHEMES``.
            ignored_layers: Layer prefixes whose linear layers stay unquantized.
            weight_block_size: Block size exposed to the linear method; defaults
                to ``[128, 128]`` when not given.
            group_size: Quantization group size used to size the MoE weight
                scales; must be positive.

        Raises:
            ValueError: If an activation scheme is unsupported or
                ``group_size`` is not positive.
        """
        super().__init__()
        self.is_checkpoint_fp8_serialized = is_checkpoint_fp8_serialized
        self.is_checkpoint_w4afp8_serialized = is_checkpoint_w4afp8_serialized
        if is_checkpoint_w4afp8_serialized:
            logger.warning(
                "Detected w4afp8 checkpoint. Please note that FusedMoE layers "
                "will be quantized with W4AFp8MoEMethod."
            )
        # Validate both schemes; previously only the MoE scheme was checked.
        for scheme in (linear_activation_scheme, moe_activation_scheme):
            if scheme not in ACTIVATION_SCHEMES:
                raise ValueError(f"Unsupported activation scheme {scheme}")
        if group_size <= 0:
            raise ValueError(f"group_size must be a positive integer, got {group_size}")
        self.linear_activation_scheme = linear_activation_scheme
        self.moe_activation_scheme = moe_activation_scheme
        self.ignored_layers = ignored_layers or []
        # Honour a caller-supplied block size instead of silently discarding it.
        self.weight_block_size = (
            list(weight_block_size) if weight_block_size is not None else [128, 128]
        )
        self.group_size = group_size

    @classmethod
    def get_name(cls) -> str:
        """Return the registry name of this quantization scheme."""
        return "w4afp8"

    @classmethod
    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        """Return the activation dtypes this scheme accepts."""
        return [torch.bfloat16, torch.float8_e4m3fn]

    @classmethod
    def get_min_capability(cls) -> int:
        """Return the minimum device capability required (90)."""
        return 90

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        """No extra config files are needed for this scheme."""
        return []

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "W4AFp8Config":
        """Build a config from a checkpoint's quantization-config dict.

        The serialization flags are derived from ``quant_method``. The
        activation schemes and weight block size are fixed. ``ignored_layers``
        is taken from the dict when present, so that ``get_quant_method`` can
        actually skip those layers.
        """
        quant_method = cls.get_from_keys(config, ["quant_method"])
        return cls(
            is_checkpoint_fp8_serialized="fp8" in quant_method,
            is_checkpoint_w4afp8_serialized="w4afp8" in quant_method,
            linear_activation_scheme="dynamic",
            moe_activation_scheme="static",
            ignored_layers=config.get("ignored_layers"),
            weight_block_size=[128, 128],
        )

    def get_quant_method(
        self, layer: torch.nn.Module, prefix: str
    ) -> Optional["QuantizeMethodBase"]:
        """Pick the quantize method for ``layer``.

        Returns ``UnquantizedLinearMethod`` for skipped linear layers,
        ``Fp8LinearMethod`` for other linear layers, ``W4AFp8MoEMethod`` for
        ``FusedMoE`` layers and ``None`` for anything else.
        """
        # Local import, as in the original code (presumably to avoid an import cycle).
        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

        if isinstance(layer, LinearBase):
            if is_layer_skipped(prefix, self.ignored_layers):
                return UnquantizedLinearMethod()
            return Fp8LinearMethod(self)
        elif isinstance(layer, FusedMoE):
            return W4AFp8MoEMethod(self)
        return None

    def get_scaled_act_names(self) -> List[str]:
        """No activations are scaled by this scheme."""
        return []
|
95
|
+
|
96
|
+
|
97
|
+
class W4AFp8MoEMethod:
    """Quantize method for ``FusedMoE`` layers under ``W4AFp8Config``.

    ``create_weights`` registers int8-typed expert weights, float32 group-wise
    weight scales and bfloat16 input scales on the layer. It also allocates
    per-expert stride and problem-size buffers on this method object.
    ``process_weights_after_loading`` casts and interleaves the weight scales
    and collapses each per-expert input-scale tensor to its maximum.
    """

    def __init__(self, quant_config: "W4AFp8Config"):
        self.quant_config = quant_config

    @staticmethod
    def _register(
        layer: Module, name: str, param: Parameter, attrs: Dict[str, Any]
    ) -> None:
        """Register ``param`` on ``layer`` under ``name`` and attach ``attrs`` to it."""
        layer.register_parameter(name, param)
        set_weight_attrs(param, attrs)

    def create_weights(
        self,
        layer: Module,
        num_experts_per_partition: int,
        hidden_size: int,
        intermediate_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """Register the W4AFp8 expert parameters and allocate kernel buffers.

        Args:
            layer: Module that receives the parameters.
            num_experts_per_partition: Number of experts held locally.
            hidden_size: Hidden size; ``w2`` has this many rows per expert.
            intermediate_size: Intermediate size; ``w13`` has
                ``2 * intermediate_size`` rows per expert.
            params_dtype: Not used by this method; dtypes are fixed below.
            **extra_weight_attrs: Attributes attached to every parameter. Must
                contain ``weight_loader``.
        """
        assert "weight_loader" in extra_weight_attrs

        num_experts = num_experts_per_partition
        group_size = self.quant_config.group_size

        def _frozen(*shape: int, dtype: torch.dtype, fill=torch.empty) -> Parameter:
            return Parameter(fill(*shape, dtype=dtype), requires_grad=False)

        # Fused gate_up_proj (column parallel). The packed last dim (// 2)
        # presumably holds two 4-bit values per int8 byte -- confirm with the kernel.
        self._register(
            layer,
            "w13_weight",
            _frozen(num_experts, intermediate_size * 2, hidden_size // 2, dtype=torch.int8),
            extra_weight_attrs,
        )
        # down_proj (row parallel)
        self._register(
            layer,
            "w2_weight",
            _frozen(num_experts, hidden_size, intermediate_size // 2, dtype=torch.int8),
            extra_weight_attrs,
        )

        # Group-wise weight scales: one scale per `group_size` input columns.
        self._register(
            layer,
            "w13_weight_scale_inv",
            _frozen(
                num_experts,
                2 * intermediate_size,
                hidden_size // group_size,
                dtype=torch.float32,
                fill=torch.zeros,
            ),
            extra_weight_attrs,
        )
        self._register(
            layer,
            "w2_weight_scale_inv",
            _frozen(
                num_experts,
                hidden_size,
                intermediate_size // group_size,
                dtype=torch.float32,
                fill=torch.zeros,
            ),
            extra_weight_attrs,
        )

        # Input scales
        self._register(
            layer,
            "w13_input_scale",
            _frozen(num_experts, 2, dtype=torch.bfloat16, fill=torch.ones),
            extra_weight_attrs,
        )
        self._register(
            layer,
            "w2_input_scale",
            _frozen(num_experts, dtype=torch.bfloat16, fill=torch.ones),
            extra_weight_attrs,
        )

        # Pre-populate the strides
        device = layer.w13_weight.device

        def _strides(value: int) -> torch.Tensor:
            return torch.full(
                (num_experts, 3), value, device=device, dtype=torch.int64
            )

        self.a_strides1 = _strides(hidden_size)
        self.c_strides1 = _strides(2 * intermediate_size)
        self.a_strides2 = _strides(intermediate_size)
        self.c_strides2 = _strides(hidden_size)
        # The B / scale strides are the very same tensors as the A / C strides.
        self.b_strides1 = self.a_strides1
        self.s_strides13 = self.c_strides1
        self.b_strides2 = self.a_strides2
        self.s_strides2 = self.c_strides2

        self.expert_offsets = torch.empty(
            (num_experts + 1), dtype=torch.int32, device=device
        )
        self.problem_sizes1 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )
        self.problem_sizes2 = torch.empty(
            (num_experts, 3), dtype=torch.int32, device=device
        )

    def _interleave_scales(self, scales: torch.Tensor) -> torch.Tensor:
        """Interleave scales in groups of 4 similar to TRT-LLM implementation.

        Maps an ``(E, N, G)`` tensor to a contiguous ``(E, G // 4, N * 4)``
        tensor with ``out[e, g, n * 4 + j] == scales[e, n, g * 4 + j]``.

        Raises:
            ValueError: If ``scales`` is not 3-D or ``G`` is not a multiple of 4.
        """
        if scales.dim() != 3 or scales.shape[2] % 4 != 0:
            raise ValueError(
                "Expected a 3-D scale tensor whose last dimension is a multiple "
                f"of 4, got shape {tuple(scales.shape)}"
            )
        num_experts, rows, num_groups = scales.shape
        # Split the group axis into (G // 4, 4) and move the G // 4 axis ahead of N.
        grouped = scales.reshape(num_experts, rows, num_groups // 4, 4)
        return (
            grouped.permute(0, 2, 1, 3)
            .reshape(num_experts, num_groups // 4, rows * 4)
            .contiguous()
        )

    @staticmethod
    def _reduce_input_scale(
        scale: torch.Tensor, dtype: torch.dtype, device: torch.device
    ) -> torch.Tensor:
        """Return a one-element ``dtype`` tensor on ``device`` holding ``scale.max()``."""
        max_value = scale.max().to(dtype).item()
        return torch.tensor([max_value], dtype=dtype, device=device)

    def process_weights_after_loading(self, layer: Module) -> None:
        """Convert the loaded scales into their inference-time layout.

        Weight scales are cast to bfloat16 and interleaved with
        ``_interleave_scales``. Each input-scale tensor is replaced by a
        one-element bfloat16 tensor (on ``w2_weight``'s device) holding its
        maximum.
        """
        dtype = torch.bfloat16
        device = layer.w2_weight.device

        # Interleave w13 (gate_up_proj) then w2 (down_proj) weight scales.
        for name in ("w13_weight_scale_inv", "w2_weight_scale_inv"):
            scales = self._interleave_scales(getattr(layer, name).to(dtype))
            setattr(layer, name, Parameter(scales, requires_grad=False))

        # Process input scales
        for name in ("w13_input_scale", "w2_input_scale"):
            reduced = self._reduce_input_scale(getattr(layer, name), dtype, device)
            setattr(layer, name, Parameter(reduced, requires_grad=False))
|
@@ -4,6 +4,7 @@ import torch
|
|
4
4
|
from torch.nn.parameter import Parameter
|
5
5
|
|
6
6
|
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
7
|
+
from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
|
7
8
|
from sglang.srt.layers.linear import LinearMethodBase
|
8
9
|
from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
|
9
10
|
from sglang.srt.layers.quantization.base_config import (
|
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
|
|
11
12
|
QuantizeMethodBase,
|
12
13
|
)
|
13
14
|
from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
|
14
|
-
from sglang.srt.utils import
|
15
|
+
from sglang.srt.utils import (
|
16
|
+
cpu_has_amx_support,
|
17
|
+
is_cpu,
|
18
|
+
is_cuda,
|
19
|
+
set_weight_attrs,
|
20
|
+
use_intel_amx_backend,
|
21
|
+
)
|
15
22
|
|
16
23
|
_is_cuda = is_cuda()
|
24
|
+
_is_cpu_amx_available = cpu_has_amx_support()
|
25
|
+
_is_cpu = is_cpu()
|
17
26
|
if _is_cuda:
|
18
27
|
from sgl_kernel import int8_scaled_mm
|
19
28
|
|
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|
72
81
|
self.quantization_config = quantization_config
|
73
82
|
|
74
83
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
84
|
+
if _is_cpu:
|
85
|
+
assert (
|
86
|
+
_is_cpu_amx_available
|
87
|
+
), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
|
88
|
+
_amx_process_weight_after_loading(layer, ["weight"])
|
89
|
+
return
|
90
|
+
|
75
91
|
layer.weight = Parameter(layer.weight.t(), requires_grad=False)
|
76
92
|
layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
|
77
93
|
|
@@ -112,6 +128,16 @@ class W8A8Int8LinearMethod(LinearMethodBase):
|
|
112
128
|
x: torch.Tensor,
|
113
129
|
bias: Optional[torch.Tensor] = None,
|
114
130
|
):
|
131
|
+
if use_intel_amx_backend(layer):
|
132
|
+
return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
|
133
|
+
x,
|
134
|
+
layer.weight,
|
135
|
+
layer.weight_scale,
|
136
|
+
bias,
|
137
|
+
x.dtype,
|
138
|
+
True, # is_vnni
|
139
|
+
)
|
140
|
+
|
115
141
|
x_q, x_scale = per_token_quant_int8(x)
|
116
142
|
|
117
143
|
return int8_scaled_mm(
|
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
|
|
206
232
|
layer.register_parameter("w2_input_scale", w2_input_scale)
|
207
233
|
|
208
234
|
def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
|
235
|
+
if _is_cpu:
|
236
|
+
assert (
|
237
|
+
_is_cpu_amx_available
|
238
|
+
), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
|
239
|
+
_amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
|
240
|
+
return
|
241
|
+
|
209
242
|
layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
|
210
243
|
layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
|
211
244
|
layer.w13_weight_scale = Parameter(
|
@@ -252,6 +285,24 @@ class W8A8Int8MoEMethod:
|
|
252
285
|
routed_scaling_factor=routed_scaling_factor,
|
253
286
|
)
|
254
287
|
|
288
|
+
if use_intel_amx_backend(layer):
|
289
|
+
return torch.ops.sgl_kernel.fused_experts_cpu(
|
290
|
+
x,
|
291
|
+
layer.w13_weight,
|
292
|
+
layer.w2_weight,
|
293
|
+
topk_weights,
|
294
|
+
topk_ids,
|
295
|
+
False, # inplace See [Note] inplace should be False in fused_experts.
|
296
|
+
True, # use_int8_w8a8
|
297
|
+
False, # use_fp8_w8a16
|
298
|
+
layer.w13_weight_scale, # w1_scale
|
299
|
+
layer.w2_weight_scale, # w2_scale
|
300
|
+
None, # block_size
|
301
|
+
layer.w13_input_scale, # a1_scale
|
302
|
+
layer.w2_input_scale, # a2_scale
|
303
|
+
True, # is_vnni
|
304
|
+
)
|
305
|
+
|
255
306
|
return fused_experts(
|
256
307
|
x,
|
257
308
|
layer.w13_weight,
|
@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|
660
660
|
beta_slow: int = 1,
|
661
661
|
mscale: float = 1,
|
662
662
|
mscale_all_dim: float = 0,
|
663
|
-
device: Optional[str] = "cuda",
|
663
|
+
device: Optional[str] = "cuda" if not _is_npu else "npu",
|
664
664
|
) -> None:
|
665
665
|
self.scaling_factor = scaling_factor
|
666
666
|
self.extrapolation_factor = extrapolation_factor
|
@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
|
|
679
679
|
)
|
680
680
|
|
681
681
|
# Re-dispatch
|
682
|
-
if _is_hip:
|
682
|
+
if _is_hip or _is_npu:
|
683
683
|
self._forward_method = self.forward_native
|
684
684
|
|
685
685
|
def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
|
@@ -1,5 +1,6 @@
|
|
1
1
|
# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.3.post1/vllm/model_executor/layers/vocab_parallel_embedding.py
|
2
2
|
|
3
|
+
import logging
|
3
4
|
from dataclasses import dataclass
|
4
5
|
from typing import List, Optional, Sequence, Tuple
|
5
6
|
|
@@ -13,6 +14,7 @@ from sglang.srt.distributed import (
|
|
13
14
|
get_tensor_model_parallel_world_size,
|
14
15
|
tensor_model_parallel_all_reduce,
|
15
16
|
)
|
17
|
+
from sglang.srt.layers.amx_utils import PackWeightMethod
|
16
18
|
from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
|
17
19
|
from sglang.srt.layers.parameter import BasevLLMParameter
|
18
20
|
from sglang.srt.layers.quantization.base_config import (
|
@@ -20,18 +22,15 @@ from sglang.srt.layers.quantization.base_config import (
|
|
20
22
|
QuantizeMethodBase,
|
21
23
|
method_has_implemented_embedding,
|
22
24
|
)
|
23
|
-
from sglang.srt.utils import
|
24
|
-
PackWeightMethod,
|
25
|
-
cpu_has_amx_support,
|
26
|
-
is_cpu,
|
27
|
-
set_weight_attrs,
|
28
|
-
)
|
25
|
+
from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
|
29
26
|
|
30
27
|
DEFAULT_VOCAB_PADDING_SIZE = 64
|
31
28
|
|
32
29
|
_is_cpu_amx_available = cpu_has_amx_support()
|
33
30
|
_is_cpu = is_cpu()
|
34
31
|
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
35
34
|
|
36
35
|
class UnquantizedEmbeddingMethod(QuantizeMethodBase):
|
37
36
|
"""Unquantized method for embeddings."""
|
@@ -250,8 +249,16 @@ class VocabParallelEmbedding(torch.nn.Module):
|
|
250
249
|
self.tp_size = 1
|
251
250
|
|
252
251
|
self.num_embeddings = num_embeddings
|
253
|
-
self.padding_size = padding_size
|
254
252
|
self.org_vocab_size = org_num_embeddings or num_embeddings
|
253
|
+
|
254
|
+
# Support the case where the vocab size is not divisible by the TP size.
|
255
|
+
if (
|
256
|
+
_is_cpu
|
257
|
+
and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
|
258
|
+
):
|
259
|
+
padding_size *= self.tp_size
|
260
|
+
self.padding_size = padding_size
|
261
|
+
|
255
262
|
num_added_embeddings = num_embeddings - self.org_vocab_size
|
256
263
|
self.use_presharded_weights = use_presharded_weights
|
257
264
|
if use_presharded_weights:
|
@@ -558,9 +565,12 @@ class ParallelLMHead(VocabParallelEmbedding):
|
|
558
565
|
)
|
559
566
|
self.quant_config = quant_config
|
560
567
|
|
561
|
-
# We only support pack LMHead if it's not quantized.
|
562
|
-
if
|
563
|
-
self
|
568
|
+
# We only support pack LMHead if it's not quantized.
|
569
|
+
if _is_cpu and _is_cpu_amx_available:
|
570
|
+
if hasattr(self, "weight") and self.weight.dtype == torch.bfloat16:
|
571
|
+
self.quant_method = PackWeightMethod(weight_names=["weight"])
|
572
|
+
else:
|
573
|
+
logger.warning("The weight of LmHead is not packed")
|
564
574
|
|
565
575
|
if bias:
|
566
576
|
self.bias = Parameter(
|
sglang/srt/lora/lora.py
CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
|
|
65
65
|
self.layers: List[LoRALayer] = nn.ModuleList(
|
66
66
|
[
|
67
67
|
LoRALayer(config, base_hf_config)
|
68
|
-
for
|
68
|
+
for _ in range(base_hf_config.num_hidden_layers)
|
69
69
|
]
|
70
70
|
)
|
71
71
|
|
@@ -88,10 +88,9 @@ class LoRAAdapter(nn.Module):
|
|
88
88
|
else:
|
89
89
|
self.weights[name] = loaded_weight.cpu()
|
90
90
|
|
91
|
-
#
|
92
|
-
for
|
93
|
-
|
94
|
-
weight_names = [name for name, _ in layer.weights.items()]
|
91
|
+
# normalize kv_proj and gate_up_proj
|
92
|
+
for layer in self.layers:
|
93
|
+
weight_names = list(layer.weights.keys())
|
95
94
|
self.normalize_qkv_proj(weight_names, layer.weights)
|
96
95
|
self.normalize_gate_up_proj(weight_names, layer.weights)
|
97
96
|
|
sglang/srt/lora/lora_manager.py
CHANGED
@@ -35,6 +35,7 @@ from sglang.srt.lora.utils import (
|
|
35
35
|
get_normalized_lora_weight_names,
|
36
36
|
get_weight_name,
|
37
37
|
)
|
38
|
+
from sglang.srt.managers.io_struct import LoRAUpdateResult
|
38
39
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
39
40
|
from sglang.srt.utils import replace_submodule
|
40
41
|
|
@@ -98,44 +99,96 @@ class LoRAManager:
|
|
98
99
|
],
|
99
100
|
)
|
100
101
|
|
101
|
-
def
|
102
|
+
def create_lora_update_result(
|
103
|
+
self, success: bool, error_message: str = ""
|
104
|
+
) -> LoRAUpdateResult:
|
105
|
+
return LoRAUpdateResult(
|
106
|
+
success=success,
|
107
|
+
error_message=error_message,
|
108
|
+
loaded_adapters={
|
109
|
+
name: config.path for name, config in self.configs.items()
|
110
|
+
},
|
111
|
+
)
|
112
|
+
|
113
|
+
def load_lora_adapters(self, lora_paths: Dict[str, str]) -> LoRAUpdateResult:
|
102
114
|
"""
|
103
115
|
Load LoRA adapters from the specified paths.
|
104
|
-
TODO (lifuhuang): This method should be exposed to the server/engine API to support dynamic LoRA loading.
|
105
116
|
|
106
117
|
Args:
|
107
118
|
lora_paths (Dict[str, str]): A dictionary mapping LoRA adapter names to their file paths.
|
108
119
|
If a LoRA adapter is already loaded, it will be skipped with a warning.
|
109
120
|
"""
|
110
121
|
|
122
|
+
results = []
|
111
123
|
for lora_name, lora_path in lora_paths.items():
|
112
|
-
|
113
|
-
|
114
|
-
|
115
|
-
|
116
|
-
|
117
|
-
|
124
|
+
result = self.load_lora_adapter(lora_name, lora_path, update_state=False)
|
125
|
+
results.append(result)
|
126
|
+
|
127
|
+
self.update_state_from_configs()
|
128
|
+
|
129
|
+
return self.create_lora_update_result(
|
130
|
+
success=all(result.success for result in results),
|
131
|
+
error_message="\n".join(
|
132
|
+
result.error_message for result in results if not result.success
|
133
|
+
),
|
134
|
+
)
|
135
|
+
|
136
|
+
def load_lora_adapter(
|
137
|
+
self, lora_name: str, lora_path: str, update_state: bool = True
|
138
|
+
) -> LoRAUpdateResult:
|
139
|
+
"""
|
140
|
+
Load a single LoRA adapter from the specified path.
|
141
|
+
|
142
|
+
Args:
|
143
|
+
lora_name (str): The name of the LoRA adapter.
|
144
|
+
lora_path (str): The file path to the LoRA adapter.
|
145
|
+
update_state (bool): Whether to refresh the internal state after loading the adapter. This is useful for batch loading.
|
146
|
+
"""
|
118
147
|
|
148
|
+
success = True
|
149
|
+
error_message = ""
|
150
|
+
|
151
|
+
if lora_name in self.loras:
|
152
|
+
success = False
|
153
|
+
error_message = f"LoRA adapter {lora_name} is skipped as it is already loaded. If you want to reload it, please unload it first."
|
154
|
+
|
155
|
+
try:
|
119
156
|
self.configs[lora_name] = LoRAConfig(lora_path)
|
157
|
+
except Exception as e:
|
158
|
+
success = False
|
159
|
+
error_message = (
|
160
|
+
f"Failed to load LoRA adapter {lora_name} from {lora_path}: {str(e)}"
|
161
|
+
)
|
120
162
|
|
121
|
-
|
163
|
+
if update_state:
|
164
|
+
self.update_state_from_configs()
|
165
|
+
|
166
|
+
return self.create_lora_update_result(
|
167
|
+
success=success,
|
168
|
+
error_message=error_message,
|
169
|
+
)
|
122
170
|
|
123
|
-
def
|
171
|
+
def unload_lora_adapter(self, lora_name: str) -> LoRAUpdateResult:
|
124
172
|
"""
|
125
173
|
Unload LoRA adapters by their names. This will remove the adapters from the memory pool and
|
126
174
|
delete the corresponding LoRA modules.
|
127
|
-
|
128
|
-
Args:
|
129
|
-
lora_names (Set[str]): A set of LoRA adapter names to unload.
|
130
175
|
"""
|
131
|
-
|
132
|
-
|
133
|
-
|
134
|
-
|
135
|
-
|
176
|
+
|
177
|
+
success = True
|
178
|
+
error_message = ""
|
179
|
+
if lora_name in self.loras:
|
180
|
+
del self.configs[lora_name]
|
181
|
+
else:
|
182
|
+
error_message = f"LoRA adapter {lora_name} is not loaded."
|
183
|
+
success = False
|
136
184
|
|
137
185
|
self.update_state_from_configs()
|
138
186
|
|
187
|
+
return self.create_lora_update_result(
|
188
|
+
success=success,
|
189
|
+
error_message=error_message,
|
190
|
+
)
|
191
|
+
|
139
192
|
def prepare_lora_batch(self, forward_batch: ForwardBatch):
|
140
193
|
# load active loras into lora memory pool
|
141
194
|
cur_uids = set(forward_batch.lora_paths)
|
@@ -372,8 +425,8 @@ class LoRAManager:
|
|
372
425
|
lora_adapter.initialize_weights()
|
373
426
|
self.loras[name] = lora_adapter
|
374
427
|
|
375
|
-
# Clean up unused LoRA adapters
|
376
|
-
for name in self.loras:
|
428
|
+
# Clean up unused LoRA adapters, copying the list to avoid modifying the dict during iteration.
|
429
|
+
for name in list(self.loras):
|
377
430
|
if name not in self.configs:
|
378
431
|
logger.info(f"Unloading LoRA adapter {name}")
|
379
432
|
del self.loras[name]
|