sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff compares publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in the public registry.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/parameter.py
CHANGED
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 
 logger = logging.getLogger(__name__)
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-        if not use_presharded_weights:
-            loaded_weight = loaded_weight.narrow(
-                self.output_dim, tp_rank * shard_size, shard_size
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
-            loaded_weight = loaded_weight.narrow(
-                self.input_dim, tp_rank * shard_size, shard_size
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
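The non-CPU branches above keep the pre-existing behaviour: each tensor-parallel rank copies only its contiguous slice of the full weight along the sharded dimension. As a point of reference, the minimal sketch below (a hypothetical standalone helper, not sglang code) shows what loaded_weight.narrow(output_dim, tp_rank * shard_size, shard_size) accomplishes; the new narrow_padded_param_and_loaded_weight helper from sglang.srt.model_loader.weight_utils appears to exist so the CPU backend can also handle shards that carry padding.

    # Minimal sketch of tensor-parallel column sharding (illustrative only).
    import torch

    def load_column_shard(param_data: torch.Tensor, full_weight: torch.Tensor,
                          tp_rank: int, output_dim: int = 0) -> None:
        # Each rank owns a contiguous slice of size shard_size along output_dim,
        # starting at tp_rank * shard_size.
        shard_size = param_data.shape[output_dim]
        shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
        assert param_data.shape == shard.shape
        param_data.copy_(shard)

    full = torch.arange(24, dtype=torch.float32).reshape(8, 3)  # full (unsharded) weight
    local = torch.empty(4, 3)                                   # rank-local shard buffer
    load_column_shard(local, full, tp_rank=1)                   # copies rows 4..7
    print(local)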
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
CHANGED
@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-        prepare_fp8_layer_for_marlin(layer,
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -330,6 +332,12 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
 
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
                 layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +434,17 @@
             )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +765,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +997,24 @@
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
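The new fp8_scaled_mm_cpu and fused_experts_cpu call sites take block-quantized fp8 weights plus a weight_scale_inv tensor holding one scale per weight_block_size tile. For orientation, here is an unoptimized reference sketch of what a block-dequantized linear computes; it is illustrative only and uses float stand-ins rather than real fp8 storage, so shapes and naming follow the hunk above but nothing here is sglang's actual kernel.

    import torch

    def block_dequant_linear(x, w_q, w_scale_inv, block_size, bias=None):
        # w_q: (N, K) quantized weight; w_scale_inv: (ceil(N/bn), ceil(K/bk)),
        # one dequantization factor per [bn, bk] tile of the weight.
        bn, bk = block_size
        n, k = w_q.shape
        # Expand each per-block factor over its tile, then crop to (N, K).
        scales = w_scale_inv.repeat_interleave(bn, dim=0)[:n]
        scales = scales.repeat_interleave(bk, dim=1)[:, :k]
        w = w_q.to(torch.float32) * scales          # dequantize
        out = x.to(torch.float32) @ w.t()
        if bias is not None:
            out = out + bias
        return out.to(x.dtype)

    x = torch.randn(4, 256, dtype=torch.bfloat16)
    w_q = torch.randn(512, 256)                     # stand-in for the fp8 weight
    w_scale = torch.rand(512 // 128, 256 // 128)    # one factor per 128x128 block
    print(block_dequant_linear(x, w_q, w_scale, (128, 128)).shape)  # (4, 512)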
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
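These two hunks, together with the removal in sglang/math_utils.py listed above, move align from sglang.math_utils into sglang.srt.utils. For context, an alignment helper of this kind conventionally rounds a value up to the next multiple of the alignment; the sketch below illustrates the idea and is not the sglang implementation.

    def align(value: int, alignment: int) -> int:
        # Round value up to the nearest multiple of alignment.
        return (value + alignment - 1) // alignment * alignment

    assert align(129, 64) == 192
    assert align(128, 64) == 128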
sglang/srt/layers/quantization/gptq.py
CHANGED
@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
sglang/srt/layers/quantization/moe_wna16.py
CHANGED
@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if
+            if all(capability is None for capability in capability_tuple)
             else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
sglang/srt/layers/quantization/quant_utils.py
ADDED
@@ -0,0 +1,166 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py
+
+from typing import Optional
+
+import numpy
+import torch
+from sgl_kernel.scalar_type import ScalarType
+
+
+def get_pack_factor(num_bits):
+    assert 32 % num_bits == 0, f"Unsupported num_bits = {num_bits}"
+    return 32 // num_bits
+
+
+def pack_cols(
+    q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    assert q_w.shape == (size_k, size_n)
+
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+
+    orig_device = q_w.device
+
+    q_w = q_w.cpu().numpy().astype(numpy.uint32)
+
+    q_res = numpy.zeros((size_k, size_n // pack_factor), dtype=numpy.uint32)
+
+    for i in range(pack_factor):
+        q_res |= q_w[:, i::pack_factor] << num_bits * i
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def unpack_cols(
+    packed_q_w: torch.Tensor,
+    num_bits: int,
+    size_k: int,
+    size_n: int,
+):
+    pack_factor = get_pack_factor(num_bits)
+    assert size_n % pack_factor == 0
+    assert packed_q_w.shape == (
+        size_k,
+        size_n // pack_factor,
+    ), "packed_q_w.shape = {} size_k = {}, size_n = {} pack_Factor = {}".format(
+        packed_q_w.shape, size_k, size_n, pack_factor
+    )
+
+    orig_device = packed_q_w.device
+
+    packed_q_w_cpu = packed_q_w.cpu().numpy().astype(numpy.uint32)
+    q_res = numpy.zeros((size_k, size_n), dtype=numpy.uint32)
+
+    mask = (1 << num_bits) - 1
+    for i in range(pack_factor):
+        vals = packed_q_w_cpu & mask
+        packed_q_w_cpu >>= num_bits
+        q_res[:, i::pack_factor] = vals
+
+    q_res = torch.from_numpy(q_res.astype(numpy.int32)).to(orig_device)
+    q_res = q_res.contiguous()
+
+    return q_res
+
+
+def quantize_weights(
+    w: torch.Tensor,
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    zero_points: bool = False,
+    ref_zero_points_after_scales: bool = False,
+):
+    assert (
+        quant_type.is_integer()
+    ), "Floating point quantization may work but has not been tested"
+    assert not zero_points or group_size is not None, (
+        "to have group zero points, group_size must be provided "
+        "(-1 group_size is channelwise)"
+    )
+
+    orig_device = w.device
+    orig_type = w.dtype
+    size_k, size_n = w.shape
+
+    assert w.is_floating_point(), "w must be float"
+
+    if group_size == -1:
+        group_size = size_k
+
+    # Reshape to [groupsize, -1]
+    if group_size is not None and group_size < size_k:
+        w = w.reshape((-1, group_size, size_n))
+        w = w.permute(1, 0, 2)
+        w = w.reshape((group_size, -1))
+
+    # Compute scale for each group
+    max_val = torch.max(w, 0, keepdim=True).values
+    min_val = torch.min(w, 0, keepdim=True).values
+
+    max_q_val = quant_type.max()
+    min_q_val = quant_type.min()
+
+    w_s = torch.Tensor([1.0]).to(w.device)  # unscaled case
+    maybe_w_zp = None
+    if group_size is not None:
+        if zero_points:
+            assert not quant_type.is_signed() and quant_type.max() > 0
+            w_s = (max_val - min_val).clamp(min=1e-5) / quant_type.max()
+            maybe_w_zp = (
+                torch.round(torch.abs(min_val / w_s)).clamp(min_q_val, max_q_val).int()
+            )
+        else:
+            # If the bias is such that there are no possible negative/positive
+            # values, set the max value to inf to avoid divide by 0
+            w_s = torch.max(
+                abs(max_val / (max_q_val if max_q_val != 0 else torch.inf)),
+                abs(min_val / (min_q_val if min_q_val != 0 else torch.inf)),
+            )
+
+    # Quantize
+    w_q = torch.round(w / w_s).int() + (maybe_w_zp if zero_points else 0)
+    w_q = torch.clamp(w_q, min_q_val, max_q_val)
+
+    # Compute ref (dequantized)
+    # For some kernels (namely Machete) the zero-points are applied after the
+    # scales are applied, for this case computing the reference in similar way
+    # allows us to use tighter error tolerances in our unit tests.
+    if ref_zero_points_after_scales and maybe_w_zp is not None:
+        w_ref = w_q.to(orig_type) * w_s - maybe_w_zp.to(orig_type) * w_s
+    else:
+        w_ref = (w_q - (maybe_w_zp if zero_points else 0)).to(orig_type) * w_s
+
+    if quant_type.has_bias():
+        w_q += quant_type.bias
+
+    # Restore original shapes
+    if group_size is not None and group_size < size_k:
+
+        def reshape_w(w):
+            w = w.reshape((group_size, -1, size_n))
+            w = w.permute(1, 0, 2)
+            w = w.reshape((size_k, size_n)).contiguous()
+            return w
+
+        w_q = reshape_w(w_q)
+        w_ref = reshape_w(w_ref)
+        w_s = w_s.reshape((-1, size_n)).contiguous()
+
+    if maybe_w_zp is not None:
+        maybe_w_zp = maybe_w_zp.reshape((-1, size_n)).contiguous()
+        maybe_w_zp = maybe_w_zp.to(device=orig_device)
+
+    return (
+        w_ref.to(device=orig_device),
+        w_q.to(device=orig_device),
+        w_s if group_size is not None else None,
+        maybe_w_zp,
+    )
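A short usage sketch for the new column-packing helpers, assuming sglang 0.4.9 and sgl_kernel are installed (the module imports ScalarType from sgl_kernel): packing 4-bit values into int32 columns and unpacking them recovers the original tensor.

    import torch
    from sglang.srt.layers.quantization.quant_utils import pack_cols, unpack_cols

    size_k, size_n, num_bits = 8, 16, 4
    q_w = torch.randint(0, 2**num_bits, (size_k, size_n), dtype=torch.int32)

    packed = pack_cols(q_w, num_bits, size_k, size_n)       # 32 // 4 = 8 values per int32
    restored = unpack_cols(packed, num_bits, size_k, size_n)

    assert packed.shape == (size_k, size_n // (32 // num_bits))  # (8, 2)
    assert torch.equal(restored, q_w)                            # lossless round trip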
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -4,6 +4,7 @@ import torch
 from torch.nn.parameter import Parameter
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import LinearMethodBase
 from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -11,9 +12,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    cpu_has_amx_support,
+    is_cpu,
+    is_cuda,
+    set_weight_attrs,
+    use_intel_amx_backend,
+)
 
 _is_cuda = is_cuda()
+_is_cpu_amx_available = cpu_has_amx_support()
+_is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
 
@@ -72,6 +81,13 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         self.quantization_config = quantization_config
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8LinearMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["weight"])
+            return
+
         layer.weight = Parameter(layer.weight.t(), requires_grad=False)
         layer.weight_scale = Parameter(layer.weight_scale.data, requires_grad=False)
 
@@ -112,6 +128,16 @@ class W8A8Int8LinearMethod(LinearMethodBase):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ):
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.int8_scaled_mm_with_quant(
+                x,
+                layer.weight,
+                layer.weight_scale,
+                bias,
+                x.dtype,
+                True,  # is_vnni
+            )
+
         x_q, x_scale = per_token_quant_int8(x)
 
         return int8_scaled_mm(
@@ -206,6 +232,13 @@ class W8A8Int8MoEMethod:
         layer.register_parameter("w2_input_scale", w2_input_scale)
 
     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "W8A8Int8MoEMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+            return
+
         layer.w13_weight = Parameter(layer.w13_weight, requires_grad=False)
         layer.w2_weight = Parameter(layer.w2_weight, requires_grad=False)
         layer.w13_weight_scale = Parameter(
@@ -252,6 +285,24 @@ class W8A8Int8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                True,  # use_int8_w8a8
+                False,  # use_fp8_w8a16
+                layer.w13_weight_scale,  # w1_scale
+                layer.w2_weight_scale,  # w2_scale
+                None,  # block_size
+                layer.w13_input_scale,  # a1_scale
+                layer.w2_input_scale,  # a2_scale
+                True,  # is_vnni
+            )
+
         return fused_experts(
             x,
             layer.w13_weight,
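On CPU with AMX, int8_scaled_mm_with_quant fuses activation quantization and the scaled int8 matmul into one op. The unfused path it bypasses, per-token dynamic int8 quantization followed by a matmul rescaled with per-token and per-channel scales, looks roughly like the reference sketch below. This is illustrative only; sglang's per_token_quant_int8 and int8_scaled_mm are fused kernels with their own signatures.

    import torch

    def per_token_quant_int8_ref(x: torch.Tensor):
        # One dynamic scale per row (token), mapping the max magnitude to 127.
        scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
        x_q = torch.round(x / scale).clamp(-128, 127).to(torch.int8)
        return x_q, scale

    def int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, out_dtype, bias=None):
        # w_q: (N, K) int8 weight, w_scale: (N,) per-output-channel scale.
        out = x_q.to(torch.float32) @ w_q.to(torch.float32).t()
        out = out * x_scale * w_scale.view(1, -1)
        if bias is not None:
            out = out + bias
        return out.to(out_dtype)

    x = torch.randn(4, 64)
    w_q = torch.randint(-128, 128, (32, 64), dtype=torch.int8)
    w_scale = torch.rand(32) * 0.01
    x_q, x_scale = per_token_quant_int8_ref(x)
    print(int8_scaled_mm_ref(x_q, w_q, x_scale, w_scale, torch.bfloat16).shape)  # (4, 32)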
sglang/srt/layers/rotary_embedding.py
CHANGED
@@ -660,7 +660,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         beta_slow: int = 1,
         mscale: float = 1,
         mscale_all_dim: float = 0,
-        device: Optional[str] = "cuda",
+        device: Optional[str] = "cuda" if not _is_npu else "npu",
     ) -> None:
         self.scaling_factor = scaling_factor
         self.extrapolation_factor = extrapolation_factor
@@ -679,7 +679,7 @@ class DeepseekScalingRotaryEmbedding(RotaryEmbedding):
         )
 
         # Re-dispatch
-        if _is_hip:
+        if _is_hip or _is_npu:
             self._forward_method = self.forward_native
 
     def _compute_inv_freq(self, scaling_factor: float) -> torch.Tensor:
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
 from sglang.srt.layers.quantization.base_config import (
@@ -20,12 +21,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
     method_has_implemented_embedding,
 )
-from sglang.srt.utils import (
-    PackWeightMethod,
-    cpu_has_amx_support,
-    is_cpu,
-    set_weight_attrs,
-)
+from sglang.srt.utils import cpu_has_amx_support, is_cpu, set_weight_attrs
 
 DEFAULT_VOCAB_PADDING_SIZE = 64
 
@@ -250,8 +246,16 @@ class VocabParallelEmbedding(torch.nn.Module):
             self.tp_size = 1
 
         self.num_embeddings = num_embeddings
-        self.padding_size = padding_size
         self.org_vocab_size = org_num_embeddings or num_embeddings
+
+        # Support the case where the vocab size is not divisible by the TP size.
+        if (
+            _is_cpu
+            and pad_vocab_size(self.org_vocab_size, padding_size) % self.tp_size != 0
+        ):
+            padding_size *= self.tp_size
+        self.padding_size = padding_size
+
         num_added_embeddings = num_embeddings - self.org_vocab_size
         self.use_presharded_weights = use_presharded_weights
         if use_presharded_weights:
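A worked example of the new padding branch: on CPU, if the padded vocabulary is not divisible by the tensor-parallel size, the padding unit is scaled by tp_size so every rank receives an equal shard. pad_vocab_size below is a local stand-in that rounds up to a multiple of the padding size (the real helper lives in vocab_parallel_embedding.py); the numbers are only an illustration.

    def pad_vocab_size(vocab_size: int, pad_to: int) -> int:
        # Round the vocabulary up to a multiple of pad_to.
        return (vocab_size + pad_to - 1) // pad_to * pad_to

    org_vocab_size, padding_size, tp_size = 32003, 64, 5

    padded = pad_vocab_size(org_vocab_size, padding_size)   # 32064
    print(padded % tp_size)                                  # 4 -> cannot shard evenly

    # The new branch scales the padding by tp_size so every rank gets an equal shard:
    padding_size *= tp_size                                  # 320
    padded = pad_vocab_size(org_vocab_size, padding_size)    # 32320
    print(padded % tp_size)                                  # 0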
sglang/srt/lora/lora.py
CHANGED
@@ -65,7 +65,7 @@ class LoRAAdapter(nn.Module):
         self.layers: List[LoRALayer] = nn.ModuleList(
             [
                 LoRALayer(config, base_hf_config)
-                for
+                for _ in range(base_hf_config.num_hidden_layers)
             ]
         )
 
@@ -88,10 +88,9 @@
             else:
                 self.weights[name] = loaded_weight.cpu()
 
-        #
-        for
-
-            weight_names = [name for name, _ in layer.weights.items()]
+        # normalize kv_proj and gate_up_proj
+        for layer in self.layers:
+            weight_names = list(layer.weights.keys())
             self.normalize_qkv_proj(weight_names, layer.weights)
             self.normalize_gate_up_proj(weight_names, layer.weights)
 