sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/srt/configs/model_config.py +24 -1
- sglang/srt/conversation.py +21 -2
- sglang/srt/disaggregation/ascend/__init__.py +6 -0
- sglang/srt/disaggregation/ascend/conn.py +44 -0
- sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
- sglang/srt/disaggregation/mooncake/conn.py +15 -14
- sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
- sglang/srt/disaggregation/utils.py +25 -3
- sglang/srt/entrypoints/engine.py +1 -1
- sglang/srt/entrypoints/http_server.py +1 -0
- sglang/srt/entrypoints/openai/protocol.py +11 -0
- sglang/srt/entrypoints/openai/serving_chat.py +7 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/kimik2_detector.py +220 -0
- sglang/srt/hf_transformers_utils.py +18 -0
- sglang/srt/jinja_template_utils.py +8 -0
- sglang/srt/layers/communicator.py +17 -4
- sglang/srt/layers/linear.py +12 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
- sglang/srt/layers/moe/topk.py +8 -2
- sglang/srt/layers/parameter.py +19 -3
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +738 -14
- sglang/srt/managers/io_struct.py +27 -2
- sglang/srt/managers/mm_utils.py +55 -94
- sglang/srt/managers/schedule_batch.py +16 -5
- sglang/srt/managers/scheduler.py +21 -1
- sglang/srt/managers/tokenizer_manager.py +16 -0
- sglang/srt/mem_cache/memory_pool.py +65 -40
- sglang/srt/model_executor/forward_batch_info.py +13 -1
- sglang/srt/model_loader/loader.py +23 -12
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +62 -17
- sglang/srt/models/deepseek_vl2.py +1 -1
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +6 -3
- sglang/srt/models/internvl.py +8 -2
- sglang/srt/models/kimi_vl.py +8 -2
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llava.py +3 -1
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpmo.py +1 -2
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mixtral_quant.py +4 -0
- sglang/srt/models/mllama4.py +13 -4
- sglang/srt/models/phi4mm.py +8 -2
- sglang/srt/models/phimoe.py +553 -0
- sglang/srt/models/qwen2.py +2 -0
- sglang/srt/models/qwen2_5_vl.py +10 -7
- sglang/srt/models/qwen2_vl.py +12 -1
- sglang/srt/models/vila.py +8 -2
- sglang/srt/multimodal/processors/base_processor.py +197 -137
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
- sglang/srt/multimodal/processors/gemma3.py +4 -2
- sglang/srt/multimodal/processors/gemma3n.py +1 -1
- sglang/srt/multimodal/processors/internvl.py +1 -1
- sglang/srt/multimodal/processors/janus_pro.py +1 -1
- sglang/srt/multimodal/processors/kimi_vl.py +1 -1
- sglang/srt/multimodal/processors/minicpm.py +4 -3
- sglang/srt/multimodal/processors/mllama4.py +1 -1
- sglang/srt/multimodal/processors/phi4mm.py +1 -1
- sglang/srt/multimodal/processors/pixtral.py +1 -1
- sglang/srt/multimodal/processors/qwen_vl.py +203 -80
- sglang/srt/multimodal/processors/vila.py +1 -1
- sglang/srt/server_args.py +11 -4
- sglang/srt/utils.py +154 -31
- sglang/version.py +1 -1
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py

@@ -1,21 +1,37 @@
-from typing import Any, Callable, Dict, List, Optional
+import importlib
+import sys
+from types import MappingProxyType
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union

 import torch
 from torch.nn.parameter import Parameter

-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import LinearMethodBase
-from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.linear import (
+    LinearMethodBase,
+    RowParallelLinear,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
+    apply_module_patch,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
+_is_npu = is_npu()
+
+if _is_npu:
+    import torch_npu
+
+    try:
+        from mindie_turbo import _ops as ops
+        from mindie_turbo.quantize.quant_utils import quant_per_tensor
+    except ImportError:
+        useMindIETurbo = False
+    else:
+        useMindIETurbo = True
+
+
+# func refers to RMSNorm.__init__
+def npu_wrapper_rmsnorm_init(func):
+    def init(self, hidden_size: int, **extra_args) -> None:
+        func(self, hidden_size, **extra_args)
+        self.ignore_anti = True
+        # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
+        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
+
+    return init
+
+
+# func refers to RMSNorm.forward_oot
+def npu_wrapper_rmsnorm_forward(func):
+    def _rmsnorm_forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
+        original_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(original_dtype)
+
+        x = (
+            torch_npu.npu_rms_norm(
+                x, self.weight.to(torch.float32), self.variance_epsilon
+            )[0]
+            + self.bias
+        )
+
+        if residual is None:
+            return x.to(original_dtype)
+        return x.to(original_dtype), residual
+
+    return _rmsnorm_forward_oot
+
+
+def npu_fused_experts(
+    hidden_states: torch.Tensor,
+    w13: torch.Tensor,
+    w13_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+):
+    original_shape = hidden_states.shape
+    original_dtype = hidden_states.dtype
+    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    num_tokens = hidden_states.shape[0]
+    num_experts = w13.shape[0]
+    row_idx_len = num_tokens * top_k
+    row_idx = (
+        torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
+        .view(top_k, -1)
+        .permute(1, 0)
+        .contiguous()
+    )
+    hidden_states, expanded_row_idx, expanded_expert_idx = (
+        torch_npu.npu_moe_init_routing(
+            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
+        )
+    )
+    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+        expanded_expert_idx, num_experts
+    )
+    expert_tokens = expert_tokens.to(torch.int64)
+    # gmm1: gate_up_proj
+    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w13],
+        scale=[w13_scale.to(scale_dtype)],
+        per_token_scale=[pertoken_scale],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=original_dtype,
+    )[0]
+    # act_fn: swiglu
+    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale.to(scale_dtype)],
+        per_token_scale=[pertoken_scale],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=original_dtype,
+    )[0]
+
+    final_hidden_states = torch_npu.npu_moe_finalize_routing(
+        hidden_states,
+        skip1=None,
+        skip2=None,
+        bias=None,
+        scales=topk_weights,
+        expanded_src_to_dst_row=expanded_row_idx,
+        export_for_source_row=topk_ids,
+    )
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states


 class W8A8Int8Config(QuantizationConfig):
@@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """

-    def __init__(self):
-
+    def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
+        self.quant_description = quant_config
+        self.is_dynamic = quant_config.get("is_dynamic", False)
+        if _is_npu:
+            if (
+                "packed_modules_mapping" in quant_config
+                and quant_config["packed_modules_mapping"] is not None
+            ):
+                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+
+            # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
+            for name in self.quant_description.keys():
+                if "norm.bias" in name:
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "__init__",
+                        [npu_wrapper_rmsnorm_init],
+                    )
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "forward_npu",
+                        [npu_wrapper_rmsnorm_forward],
+                    )

     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return (
+            [torch.float16, torch.bfloat16]
+            if not _is_npu
+            else [torch.int8, torch.float16, torch.bfloat16]
+        )

     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        if _is_npu:
+            raise NotImplementedError(
+                'NPU hardware does not support "get_min_capability" feature.'
+            )
+        else:
+            return 75

     @classmethod
     def get_name(self) -> str:
@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):

     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
-        return cls()
+        return cls(config)

     def get_quant_method(
         self,
@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

-        if isinstance(layer, LinearBase):
-            return W8A8Int8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return W8A8Int8MoEMethod(self)
-        return None
+        if _is_npu:
+            if isinstance(layer, LinearBase):
+                prefix_in_quant_config = prefix
+                proj_name = prefix.split(".")[-1]
+                if proj_name in self.packed_modules_mapping:
+                    prefix_in_quant_config = prefix.replace(
+                        proj_name, self.packed_modules_mapping[proj_name][0]
+                    )
+                self.is_dynamic = (
+                    self.quant_description[prefix_in_quant_config + ".weight"]
+                    == "W8A8_DYNAMIC"
+                )
+                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
+                    return UnquantizedLinearMethod()
+                return (
+                    NPU_W8A8DynamicLinearMethod(self)
+                    if self.is_dynamic
+                    else NPU_W8A8LinearMethod(self)
+                )
+            elif isinstance(layer, FusedMoE):
+                return NPU_W8A8MoEMethod(self)
+            return None
+        else:
+            if isinstance(layer, LinearBase):
+                return W8A8Int8LinearMethod(self)
+            elif isinstance(layer, FusedMoE):
+                return W8A8Int8MoEMethod(self)
+            return None
+
+    def is_layer_skipped(
+        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    ):
+        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
+        proj_name = prefix.split(".")[-1]
+        if proj_name in fused_mapping:
+            shard_prefixes = [
+                prefix.replace(proj_name, shard_proj_name)
+                for shard_proj_name in fused_mapping[proj_name]
+            ]
+
+            is_skipped = None
+            for shard_prefix in shard_prefixes:
+                is_shard_skipped = (
+                    self.quant_description[shard_prefix + ".weight"] == "FLOAT"
+                )
+
+                if is_skipped is None:
+                    is_skipped = is_shard_skipped
+                elif is_shard_skipped != is_skipped:
+                    raise ValueError(
+                        f"Detected some but not all shards of {prefix} "
+                        "are quantized. All shards of fused layers "
+                        "to have the same precision."
+                    )
+        else:
+            is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
+
+        assert is_skipped is not None
+        return is_skipped

     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -321,3 +550,498 @@ class W8A8Int8MoEMethod:
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
         )
+
+
+class NPU_W8A8LinearMethodImpl:
+    """Linear method for NPU W8A8."""
+
+    def __init__(self) -> None:
+        # aclnn quant matmul requires to transpose matrix B, set to true by default.
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
+        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        return params_dict
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
+        if params_dtype == torch.bfloat16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
+        elif params_dtype == torch.float16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        if original_dtype != torch.int8:
+            x = torch_npu.npu_quantize(
+                x,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_offset,
+                torch.qint8,
+                -1,
+                True,
+            )
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        return torch_npu.npu_quant_matmul(
+            x,
+            layer.weight,
+            layer.deq_scale,
+            bias=quant_bias,
+            output_dtype=original_dtype,
+        )
+
+    def process_weights_after_loading(self, layer):
+        expanding_factor = layer.weight.data.shape[1]
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
+        layer.aclnn_input_offset = torch.nn.Parameter(
+            layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
+        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+
+
+class NPU_W8A8LinearMethodMTImpl:
+    """Linear method for NPU W8A8."""
+
+    def __init__(self) -> None:
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
+        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        return params_dict
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
+        if params_dtype == torch.bfloat16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
+        elif params_dtype == torch.float16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        if original_dtype != torch.int8:
+            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        return ops.quant_matmul(
+            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
+        )
+
+    def process_weights_after_loading(self, layer):
+        layer.aclnn_deq_scale = torch.nn.Parameter(
+            torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
+            requires_grad=False,
+        )
+
+
+class NPU_W8A8LinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class search for specific quantization
+    implementation supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = (
+            NPU_W8A8LinearMethodMTImpl()
+            if useMindIETurbo
+            else NPU_W8A8LinearMethodImpl()
+        )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        weight_dict = self.quant_method.get_weight(
+            input_size_per_partition, output_size_per_partition, params_dtype
+        )
+        for weight_name, weight_param in weight_dict.items():
+            param = torch.nn.Parameter(weight_param, requires_grad=False)
+            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+            layer.register_parameter(weight_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
+        for pertensor_name, pertensor_param in pertensor_dict.items():
+            param = PerTensorScaleParameter(
+                data=pertensor_param, weight_loader=weight_loader
+            )
+            # disable warning
+            param.ignore_warning = True
+            layer.register_parameter(pertensor_name, param)
+
+        perchannel_dict = self.quant_method.get_perchannel_param(
+            output_size_per_partition, params_dtype
+        )
+        for perchannel_name, perchannel_param in perchannel_dict.items():
+            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(perchannel_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(layer, RowParallelLinear):
+            tp_rank = get_tensor_model_parallel_rank()
+            return self.quant_method.apply(layer, x, bias, tp_rank)
+        return self.quant_method.apply(layer, x, bias)
+
+
+class NPU_W8A8DynamicLinearMethodImpl:
+    """Linear method for NPU W8A8_DYNAMIC."""
+
+    def __init__(self):
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int, output_size: int, params_dtype: torch.dtype
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        # use ATB quantize
+        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        return torch_npu.npu_quant_matmul(
+            quant_out,
+            layer.weight,
+            layer.weight_scale,
+            pertoken_scale=dynamic_scale,
+            bias=bias,
+            output_dtype=original_dtype,
+        )
+
+    def process_weights_after_loading(self, layer):
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight_scale.data = layer.weight_scale.data.flatten()
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
+        layer.weight_offset.data = layer.weight_offset.data.flatten()
+
+
+class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class search for specific quantization
+    implementations supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = NPU_W8A8DynamicLinearMethodImpl()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        weight_dict = self.quant_method.get_weight(
+            input_size_per_partition, output_size_per_partition, params_dtype
+        )
+        for weight_name, weight_param in weight_dict.items():
+            param = torch.nn.Parameter(weight_param, requires_grad=False)
+            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+            layer.register_parameter(weight_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
+        for pertensor_name, pertensor_param in pertensor_dict.items():
+            param = PerTensorScaleParameter(
+                data=pertensor_param, weight_loader=weight_loader
+            )
+            # disable warning
+            param.ignore_warning = True
+            layer.register_parameter(pertensor_name, param)
+
+        perchannel_dict = self.quant_method.get_perchannel_param(
+            output_size_per_partition, params_dtype
+        )
+        for perchannel_name, perchannel_param in perchannel_dict.items():
+            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(perchannel_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(layer, RowParallelLinear):
+            tp_rank = get_tensor_model_parallel_rank()
+            return self.quant_method.apply(layer, x, bias, tp_rank)
+        return self.quant_method.apply(layer, x, bias)
+
+
+class NPU_W8A8MoEMethod:
+    """MoE method for NPU quantization.
+
+    This class search for specific quantization
+    implementations supported on NPU hardware for moe methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = self
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        self.num_experts = num_experts
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        # weight
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+        # scale
+        w13_weight_scale = torch.nn.Parameter(
+            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        w2_weight_scale = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        # offset
+        w13_weight_offset = torch.nn.Parameter(
+            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_offset", w13_weight_offset)
+        set_weight_attrs(w13_weight_offset, extra_weight_attrs)
+        w2_weight_offset = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_offset", w2_weight_offset)
+        set_weight_attrs(w2_weight_offset, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(
+            layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False
+        )
+        layer.w2_weight = Parameter(
+            layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False
+        )
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w13_weight_offset = Parameter(
+            layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w2_weight_offset = Parameter(
+            layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer,
+        x,
+        router_logits,
+        top_k,
+        renormalize,
+        use_grouped_topk,
+        topk_group,
+        num_expert_group,
+        num_fused_shared_experts,
+        custom_routing_function,
+        correction_bias,
+        activation,
+        apply_router_weight_on_input,
+        routed_scaling_factor,
+        **kwargs,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.topk import select_experts
+
+        global_num_experts = router_logits.shape[-1]
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,
+                bias=correction_bias,
+                k_group=topk_group,
+                group_count=num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                torch_native=True,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = topk_weights.to(x.dtype)
+        return npu_fused_experts(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_scale=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_scale=layer.w2_weight_scale,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+        )