sglang 0.4.9.post1__py3-none-any.whl → 0.4.9.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (75)
  1. sglang/srt/configs/model_config.py +24 -1
  2. sglang/srt/conversation.py +21 -2
  3. sglang/srt/disaggregation/ascend/__init__.py +6 -0
  4. sglang/srt/disaggregation/ascend/conn.py +44 -0
  5. sglang/srt/disaggregation/ascend/transfer_engine.py +58 -0
  6. sglang/srt/disaggregation/mooncake/conn.py +15 -14
  7. sglang/srt/disaggregation/mooncake/transfer_engine.py +17 -8
  8. sglang/srt/disaggregation/utils.py +25 -3
  9. sglang/srt/entrypoints/engine.py +1 -1
  10. sglang/srt/entrypoints/http_server.py +1 -0
  11. sglang/srt/entrypoints/openai/protocol.py +11 -0
  12. sglang/srt/entrypoints/openai/serving_chat.py +7 -0
  13. sglang/srt/function_call/function_call_parser.py +2 -0
  14. sglang/srt/function_call/kimik2_detector.py +220 -0
  15. sglang/srt/hf_transformers_utils.py +18 -0
  16. sglang/srt/jinja_template_utils.py +8 -0
  17. sglang/srt/layers/communicator.py +17 -4
  18. sglang/srt/layers/linear.py +12 -2
  19. sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
  20. sglang/srt/layers/moe/ep_moe/layer.py +2 -1
  21. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -2
  22. sglang/srt/layers/moe/topk.py +8 -2
  23. sglang/srt/layers/parameter.py +19 -3
  24. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  25. sglang/srt/layers/quantization/moe_wna16.py +1 -2
  26. sglang/srt/layers/quantization/w8a8_int8.py +738 -14
  27. sglang/srt/managers/io_struct.py +27 -2
  28. sglang/srt/managers/mm_utils.py +55 -94
  29. sglang/srt/managers/schedule_batch.py +16 -5
  30. sglang/srt/managers/scheduler.py +21 -1
  31. sglang/srt/managers/tokenizer_manager.py +16 -0
  32. sglang/srt/mem_cache/memory_pool.py +65 -40
  33. sglang/srt/model_executor/forward_batch_info.py +13 -1
  34. sglang/srt/model_loader/loader.py +23 -12
  35. sglang/srt/models/deepseek_janus_pro.py +1 -1
  36. sglang/srt/models/deepseek_v2.py +62 -17
  37. sglang/srt/models/deepseek_vl2.py +1 -1
  38. sglang/srt/models/gemma3_mm.py +1 -1
  39. sglang/srt/models/gemma3n_mm.py +6 -3
  40. sglang/srt/models/internvl.py +8 -2
  41. sglang/srt/models/kimi_vl.py +8 -2
  42. sglang/srt/models/llama.py +2 -0
  43. sglang/srt/models/llava.py +3 -1
  44. sglang/srt/models/llavavid.py +1 -1
  45. sglang/srt/models/minicpmo.py +1 -2
  46. sglang/srt/models/minicpmv.py +1 -1
  47. sglang/srt/models/mixtral_quant.py +4 -0
  48. sglang/srt/models/mllama4.py +13 -4
  49. sglang/srt/models/phi4mm.py +8 -2
  50. sglang/srt/models/phimoe.py +553 -0
  51. sglang/srt/models/qwen2.py +2 -0
  52. sglang/srt/models/qwen2_5_vl.py +10 -7
  53. sglang/srt/models/qwen2_vl.py +12 -1
  54. sglang/srt/models/vila.py +8 -2
  55. sglang/srt/multimodal/processors/base_processor.py +197 -137
  56. sglang/srt/multimodal/processors/deepseek_vl_v2.py +1 -1
  57. sglang/srt/multimodal/processors/gemma3.py +4 -2
  58. sglang/srt/multimodal/processors/gemma3n.py +1 -1
  59. sglang/srt/multimodal/processors/internvl.py +1 -1
  60. sglang/srt/multimodal/processors/janus_pro.py +1 -1
  61. sglang/srt/multimodal/processors/kimi_vl.py +1 -1
  62. sglang/srt/multimodal/processors/minicpm.py +4 -3
  63. sglang/srt/multimodal/processors/mllama4.py +1 -1
  64. sglang/srt/multimodal/processors/phi4mm.py +1 -1
  65. sglang/srt/multimodal/processors/pixtral.py +1 -1
  66. sglang/srt/multimodal/processors/qwen_vl.py +203 -80
  67. sglang/srt/multimodal/processors/vila.py +1 -1
  68. sglang/srt/server_args.py +11 -4
  69. sglang/srt/utils.py +154 -31
  70. sglang/version.py +1 -1
  71. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/METADATA +4 -3
  72. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/RECORD +75 -70
  73. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/licenses/LICENSE +0 -0
  75. {sglang-0.4.9.post1.dist-info → sglang-0.4.9.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/w8a8_int8.py +738 -14

@@ -1,21 +1,37 @@
-from typing import Any, Callable, Dict, List, Optional
+import importlib
+import sys
+from types import MappingProxyType
+from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple, Union
 
 import torch
 from torch.nn.parameter import Parameter
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
 from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
-from sglang.srt.layers.linear import LinearMethodBase
-from sglang.srt.layers.parameter import ChannelQuantScaleParameter, ModelWeightParameter
+from sglang.srt.layers.linear import (
+    LinearMethodBase,
+    RowParallelLinear,
+    UnquantizedLinearMethod,
+)
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_kernel import per_token_quant_int8
 from sglang.srt.utils import (
+    apply_module_patch,
     cpu_has_amx_support,
     is_cpu,
     is_cuda,
+    is_npu,
     set_weight_attrs,
     use_intel_amx_backend,
 )
@@ -25,6 +41,134 @@ _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
 if _is_cuda:
     from sgl_kernel import int8_scaled_mm
+_is_npu = is_npu()
+
+if _is_npu:
+    import torch_npu
+
+    try:
+        from mindie_turbo import _ops as ops
+        from mindie_turbo.quantize.quant_utils import quant_per_tensor
+    except ImportError:
+        useMindIETurbo = False
+    else:
+        useMindIETurbo = True
+
+
+# func refers to RMSNorm.__init__
+def npu_wrapper_rmsnorm_init(func):
+    def init(self, hidden_size: int, **extra_args) -> None:
+        func(self, hidden_size, **extra_args)
+        self.ignore_anti = True
+        # The Ascend w8a8_int8 quantization requires adding a bias in rmsnorm
+        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)
+
+    return init
+
+
+# func refers to RMSNorm.forward_oot
+def npu_wrapper_rmsnorm_forward(func):
+    def _rmsnorm_forward_oot(
+        self,
+        x: torch.Tensor,
+        residual: Optional[torch.Tensor] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
+        if not x.is_contiguous():
+            x = x.contiguous()
+        original_dtype = x.dtype
+        x = x.to(torch.float32)
+        if residual is not None:
+            x = x + residual.to(torch.float32)
+            residual = x.to(original_dtype)
+
+        x = (
+            torch_npu.npu_rms_norm(
+                x, self.weight.to(torch.float32), self.variance_epsilon
+            )[0]
+            + self.bias
+        )
+
+        if residual is None:
+            return x.to(original_dtype)
+        return x.to(original_dtype), residual
+
+    return _rmsnorm_forward_oot
+
+
+def npu_fused_experts(
+    hidden_states: torch.Tensor,
+    w13: torch.Tensor,
+    w13_scale: torch.Tensor,
+    w2: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    top_k: int,
+):
+    original_shape = hidden_states.shape
+    original_dtype = hidden_states.dtype
+    scale_dtype = original_dtype if original_dtype == torch.bfloat16 else torch.float32
+    if len(original_shape) == 3:
+        hidden_states = hidden_states.view(-1, hidden_states.shape[-1])
+    num_tokens = hidden_states.shape[0]
+    num_experts = w13.shape[0]
+    row_idx_len = num_tokens * top_k
+    row_idx = (
+        torch.arange(0, row_idx_len, dtype=torch.int32, device=topk_weights.device)
+        .view(top_k, -1)
+        .permute(1, 0)
+        .contiguous()
+    )
+    hidden_states, expanded_row_idx, expanded_expert_idx = (
+        torch_npu.npu_moe_init_routing(
+            hidden_states, row_idx=row_idx, expert_idx=topk_ids, active_num=num_tokens
+        )
+    )
+    expert_tokens = torch_npu.npu_moe_compute_expert_tokens(
+        expanded_expert_idx, num_experts
+    )
+    expert_tokens = expert_tokens.to(torch.int64)
+    # gmm1: gate_up_proj
+    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w13],
+        scale=[w13_scale.to(scale_dtype)],
+        per_token_scale=[pertoken_scale],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=original_dtype,
+    )[0]
+    # act_fn: swiglu
+    hidden_states = torch_npu.npu_swiglu(hidden_states)
+    hidden_states, pertoken_scale = torch_npu.npu_dynamic_quant(hidden_states)
+    # gmm2: down_proj
+    hidden_states = torch_npu.npu_grouped_matmul(
+        x=[hidden_states],
+        weight=[w2],
+        scale=[w2_scale.to(scale_dtype)],
+        per_token_scale=[pertoken_scale],
+        split_item=2,
+        group_list_type=0,
+        group_type=0,
+        group_list=expert_tokens,
+        output_dtype=original_dtype,
+    )[0]
+
+    final_hidden_states = torch_npu.npu_moe_finalize_routing(
+        hidden_states,
+        skip1=None,
+        skip2=None,
+        bias=None,
+        scales=topk_weights,
+        expanded_src_to_dst_row=expanded_row_idx,
+        export_for_source_row=topk_ids,
+    )
+    if len(original_shape) == 3:
+        final_hidden_states = final_hidden_states.view(original_shape)
+    return final_hidden_states
 
 
 class W8A8Int8Config(QuantizationConfig):
@@ -34,16 +178,47 @@ class W8A8Int8Config(QuantizationConfig):
     - Activation: dynamic, per-token, symmetric
     """
 
-    def __init__(self):
-        pass
+    def __init__(self, quant_config: Dict[str, Any]):
+        super().__init__()
+        self.quant_description = quant_config
+        self.is_dynamic = quant_config.get("is_dynamic", False)
+        if _is_npu:
+            if (
+                "packed_modules_mapping" in quant_config
+                and quant_config["packed_modules_mapping"] is not None
+            ):
+                self.packed_modules_mapping = quant_config["packed_modules_mapping"]
+
+            # Ascend w8a8_int8 quantization with bias, use wrappers to isolate the effects between models
+            for name in self.quant_description.keys():
+                if "norm.bias" in name:
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "__init__",
+                        [npu_wrapper_rmsnorm_init],
+                    )
+                    apply_module_patch(
+                        "sglang.srt.layers.layernorm.RMSNorm",
+                        "forward_npu",
+                        [npu_wrapper_rmsnorm_forward],
+                    )
 
     @classmethod
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
-        return [torch.float16, torch.bfloat16]
+        return (
+            [torch.float16, torch.bfloat16]
+            if not _is_npu
+            else [torch.int8, torch.float16, torch.bfloat16]
+        )
 
     @classmethod
     def get_min_capability(cls) -> int:
-        return 75
+        if _is_npu:
+            raise NotImplementedError(
+                'NPU hardware does not support "get_min_capability" feature.'
+            )
+        else:
+            return 75
 
     @classmethod
     def get_name(self) -> str:
@@ -55,7 +230,7 @@ class W8A8Int8Config(QuantizationConfig):
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> "W8A8Int8Config":
-        return cls()
+        return cls(config)
 
     def get_quant_method(
         self,
@@ -65,11 +240,65 @@ class W8A8Int8Config(QuantizationConfig):
         from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
-        if isinstance(layer, LinearBase):
-            return W8A8Int8LinearMethod(self)
-        elif isinstance(layer, FusedMoE):
-            return W8A8Int8MoEMethod(self)
-        return None
+        if _is_npu:
+            if isinstance(layer, LinearBase):
+                prefix_in_quant_config = prefix
+                proj_name = prefix.split(".")[-1]
+                if proj_name in self.packed_modules_mapping:
+                    prefix_in_quant_config = prefix.replace(
+                        proj_name, self.packed_modules_mapping[proj_name][0]
+                    )
+                self.is_dynamic = (
+                    self.quant_description[prefix_in_quant_config + ".weight"]
+                    == "W8A8_DYNAMIC"
+                )
+                if self.is_layer_skipped(prefix, self.packed_modules_mapping):
+                    return UnquantizedLinearMethod()
+                return (
+                    NPU_W8A8DynamicLinearMethod(self)
+                    if self.is_dynamic
+                    else NPU_W8A8LinearMethod(self)
+                )
+            elif isinstance(layer, FusedMoE):
+                return NPU_W8A8MoEMethod(self)
+            return None
+        else:
+            if isinstance(layer, LinearBase):
+                return W8A8Int8LinearMethod(self)
+            elif isinstance(layer, FusedMoE):
+                return W8A8Int8MoEMethod(self)
+            return None
+
+    def is_layer_skipped(
+        self, prefix: str, fused_mapping: Mapping[str, List[str]] = MappingProxyType({})
+    ):
+        # adapted from vllm.model_executor.layers.quantization.utils.quant_utils.is_layer_skipped
+        proj_name = prefix.split(".")[-1]
+        if proj_name in fused_mapping:
+            shard_prefixes = [
+                prefix.replace(proj_name, shard_proj_name)
+                for shard_proj_name in fused_mapping[proj_name]
+            ]
+
+            is_skipped = None
+            for shard_prefix in shard_prefixes:
+                is_shard_skipped = (
+                    self.quant_description[shard_prefix + ".weight"] == "FLOAT"
+                )
+
+                if is_skipped is None:
+                    is_skipped = is_shard_skipped
+                elif is_shard_skipped != is_skipped:
+                    raise ValueError(
+                        f"Detected some but not all shards of {prefix} "
+                        "are quantized. All shards of fused layers "
+                        "to have the same precision."
+                    )
+        else:
+            is_skipped = self.quant_description[prefix + ".weight"] == "FLOAT"
+
+        assert is_skipped is not None
+        return is_skipped
 
     def get_scaled_act_names(self) -> List[str]:
         return []
@@ -321,3 +550,498 @@ class W8A8Int8MoEMethod:
             no_combine=no_combine,
             routed_scaling_factor=routed_scaling_factor,
         )
+
+
+class NPU_W8A8LinearMethodImpl:
+    """Linear method for NPU W8A8."""
+
+    def __init__(self) -> None:
+        # aclnn quant matmul requires to transpose matrix B, set to true by default.
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
+        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        return params_dict
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
+        if params_dtype == torch.bfloat16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
+        elif params_dtype == torch.float16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        if original_dtype != torch.int8:
+            x = torch_npu.npu_quantize(
+                x,
+                layer.aclnn_input_scale,
+                layer.aclnn_input_offset,
+                torch.qint8,
+                -1,
+                True,
+            )
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        return torch_npu.npu_quant_matmul(
+            x,
+            layer.weight,
+            layer.deq_scale,
+            bias=quant_bias,
+            output_dtype=original_dtype,
+        )
+
+    def process_weights_after_loading(self, layer):
+        expanding_factor = layer.weight.data.shape[1]
+        layer.aclnn_input_scale = torch.nn.Parameter(
+            layer.input_scale.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
+        layer.aclnn_input_offset = torch.nn.Parameter(
+            layer.input_offset.data.repeat(expanding_factor).to(device="npu"),
+            requires_grad=False,
+        )
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight_scale.data = torch.flatten(layer.weight_scale.data)
+        layer.weight_offset.data = torch.flatten(layer.weight_offset.data)
+
+
+class NPU_W8A8LinearMethodMTImpl:
+    """Linear method for NPU W8A8."""
+
+    def __init__(self) -> None:
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype = torch.bfloat16,
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["input_scale"] = torch.empty(1, dtype=params_dtype)
+        params_dict["input_offset"] = torch.empty(1, dtype=torch.int8)
+        return params_dict
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["quant_bias"] = torch.empty(output_size, dtype=torch.int32)
+        if params_dtype == torch.bfloat16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.float32)
+        elif params_dtype == torch.float16:
+            params_dict["deq_scale"] = torch.empty(output_size, dtype=torch.int64)
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        if original_dtype != torch.int8:
+            x = quant_per_tensor(x, layer.input_scale, layer.input_offset)
+
+        quant_bias = layer.quant_bias if tp_rank == 0 else None
+        return ops.quant_matmul(
+            x=x, weight=layer.weight, deq_scale=layer.deq_scale, deq_bias=quant_bias
+        )
+
+    def process_weights_after_loading(self, layer):
+        layer.aclnn_deq_scale = torch.nn.Parameter(
+            torch_npu.npu_trans_quant_param(layer.deq_scale.npu()).to(device="npu"),
+            requires_grad=False,
+        )
+
+
+class NPU_W8A8LinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class search for specific quantization
+    implementation supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = (
+            NPU_W8A8LinearMethodMTImpl()
+            if useMindIETurbo
+            else NPU_W8A8LinearMethodImpl()
+        )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        weight_dict = self.quant_method.get_weight(
+            input_size_per_partition, output_size_per_partition, params_dtype
+        )
+        for weight_name, weight_param in weight_dict.items():
+            param = torch.nn.Parameter(weight_param, requires_grad=False)
+            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+            layer.register_parameter(weight_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
+        for pertensor_name, pertensor_param in pertensor_dict.items():
+            param = PerTensorScaleParameter(
+                data=pertensor_param, weight_loader=weight_loader
+            )
+            # disable warning
+            param.ignore_warning = True
+            layer.register_parameter(pertensor_name, param)
+
+        perchannel_dict = self.quant_method.get_perchannel_param(
+            output_size_per_partition, params_dtype
+        )
+        for perchannel_name, perchannel_param in perchannel_dict.items():
+            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(perchannel_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(layer, RowParallelLinear):
+            tp_rank = get_tensor_model_parallel_rank()
+            return self.quant_method.apply(layer, x, bias, tp_rank)
+        return self.quant_method.apply(layer, x, bias)
+
+
+class NPU_W8A8DynamicLinearMethodImpl:
+    """Linear method for NPU W8A8_DYNAMIC."""
+
+    def __init__(self):
+        self.transpose_weight = True
+
+    @staticmethod
+    def get_weight(
+        input_size: int, output_size: int, params_dtype: torch.dtype
+    ) -> Dict[str, Any]:
+        params_dict = {"weight": torch.empty(output_size, input_size, dtype=torch.int8)}
+        return params_dict
+
+    @staticmethod
+    def get_pertensor_param(params_dtype: torch.dtype) -> Dict[str, Any]:
+        return {}
+
+    @staticmethod
+    def get_perchannel_param(
+        output_size: int,
+        params_dtype: torch.dtype,
+    ) -> Dict[str, Any]:
+        params_dict = {}
+        params_dict["weight_scale"] = torch.empty(output_size, 1, dtype=params_dtype)
+        params_dict["weight_offset"] = torch.empty(output_size, 1, dtype=params_dtype)
+        return params_dict
+
+    @staticmethod
+    def apply(
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+        tp_rank: Optional[int] = 0,
+    ) -> torch.Tensor:
+        original_dtype = x.dtype
+        # use ATB quantize
+        quant_out, dynamic_scale = torch_npu.npu_dynamic_quant(x)
+        return torch_npu.npu_quant_matmul(
+            quant_out,
+            layer.weight,
+            layer.weight_scale,
+            pertoken_scale=dynamic_scale,
+            bias=bias,
+            output_dtype=original_dtype,
+        )
+
+    def process_weights_after_loading(self, layer):
+        if self.transpose_weight:
+            layer.weight.data = layer.weight.data.transpose(0, 1).contiguous()
+        layer.weight_scale.data = layer.weight_scale.data.flatten()
+        layer.weight_scale_fp32 = layer.weight_scale.data.to(torch.float32)
+        layer.weight_offset.data = layer.weight_offset.data.flatten()
+
+
+class NPU_W8A8DynamicLinearMethod(LinearMethodBase):
+    """Linear method for NPU quantization.
+
+    This class search for specific quantization
+    implementations supported on NPU hardware for linear methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = NPU_W8A8DynamicLinearMethodImpl()
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        weight_dict = self.quant_method.get_weight(
+            input_size_per_partition, output_size_per_partition, params_dtype
+        )
+        for weight_name, weight_param in weight_dict.items():
+            param = torch.nn.Parameter(weight_param, requires_grad=False)
+            set_weight_attrs(param, {"input_dim": 1, "output_dim": 0})
+            layer.register_parameter(weight_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+        pertensor_dict = self.quant_method.get_pertensor_param(params_dtype)
+        for pertensor_name, pertensor_param in pertensor_dict.items():
+            param = PerTensorScaleParameter(
+                data=pertensor_param, weight_loader=weight_loader
+            )
+            # disable warning
+            param.ignore_warning = True
+            layer.register_parameter(pertensor_name, param)
+
+        perchannel_dict = self.quant_method.get_perchannel_param(
+            output_size_per_partition, params_dtype
+        )
+        for perchannel_name, perchannel_param in perchannel_dict.items():
+            param = torch.nn.Parameter(perchannel_param, requires_grad=False)
+            set_weight_attrs(param, {"output_dim": 0})
+            layer.register_parameter(perchannel_name, param)
+            set_weight_attrs(param, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        if hasattr(self.quant_method, "process_weights_after_loading"):
+            self.quant_method.process_weights_after_loading(layer)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        if isinstance(layer, RowParallelLinear):
+            tp_rank = get_tensor_model_parallel_rank()
+            return self.quant_method.apply(layer, x, bias, tp_rank)
+        return self.quant_method.apply(layer, x, bias)
+
+
+class NPU_W8A8MoEMethod:
+    """MoE method for NPU quantization.
+
+    This class search for specific quantization
+    implementations supported on NPU hardware for moe methods.
+
+    Args:
+        quant_config: The NPU quantization config.
+    """
+
+    def __init__(self, quantization_config: W8A8Int8Config) -> None:
+        self.quantization_config = quantization_config
+        self.quant_method = self
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: List[int],
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        self.num_experts = num_experts
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
+        )
+
+        # weight
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts, 2 * intermediate_size, hidden_size, dtype=torch.int8
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+        w2_weight = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, intermediate_size, dtype=torch.int8),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+        # scale
+        w13_weight_scale = torch.nn.Parameter(
+            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        w2_weight_scale = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        # offset
+        w13_weight_offset = torch.nn.Parameter(
+            torch.empty(num_experts, 2 * intermediate_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_offset", w13_weight_offset)
+        set_weight_attrs(w13_weight_offset, extra_weight_attrs)
+        w2_weight_offset = torch.nn.Parameter(
+            torch.empty(num_experts, hidden_size, 1, dtype=torch.float32),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_offset", w2_weight_offset)
+        set_weight_attrs(w2_weight_offset, extra_weight_attrs)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.w13_weight = Parameter(
+            layer.w13_weight.data.transpose(1, 2).contiguous(), requires_grad=False
+        )
+        layer.w2_weight = Parameter(
+            layer.w2_weight.data.transpose(1, 2).contiguous(), requires_grad=False
+        )
+        layer.w13_weight_scale = Parameter(
+            layer.w13_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w2_weight_scale = Parameter(
+            layer.w2_weight_scale.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w13_weight_offset = Parameter(
+            layer.w13_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+        layer.w2_weight_offset = Parameter(
+            layer.w2_weight_offset.data.squeeze(-1).contiguous(), requires_grad=False
+        )
+
+    def apply(
+        self,
+        layer,
+        x,
+        router_logits,
+        top_k,
+        renormalize,
+        use_grouped_topk,
+        topk_group,
+        num_expert_group,
+        num_fused_shared_experts,
+        custom_routing_function,
+        correction_bias,
+        activation,
+        apply_router_weight_on_input,
+        routed_scaling_factor,
+        **kwargs,
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.topk import select_experts
+
+        global_num_experts = router_logits.shape[-1]
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            topk_weights, topk_ids, _ = torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=top_k,
+                bias=correction_bias,
+                k_group=topk_group,
+                group_count=num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            topk_weights, topk_ids = select_experts(
+                hidden_states=x,
+                router_logits=router_logits,
+                use_grouped_topk=use_grouped_topk,
+                top_k=top_k,
+                renormalize=renormalize,
+                topk_group=topk_group,
+                num_expert_group=num_expert_group,
+                num_fused_shared_experts=num_fused_shared_experts,
+                custom_routing_function=custom_routing_function,
+                correction_bias=correction_bias,
+                torch_native=True,
+                routed_scaling_factor=routed_scaling_factor,
+            )
+        topk_ids = topk_ids.to(torch.int32)
+        topk_weights = topk_weights.to(x.dtype)
+        return npu_fused_experts(
+            hidden_states=x,
+            w13=layer.w13_weight,
+            w13_scale=layer.w13_weight_scale,
+            w2=layer.w2_weight,
+            w2_scale=layer.w2_weight_scale,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            top_k=top_k,
+        )
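
A note on the RMSNorm patching in this diff, with a minimal sketch. W8A8Int8Config.__init__ does not edit RMSNorm itself: when the checkpoint's quant description contains "norm.bias" entries, it calls apply_module_patch with wrapper factories (npu_wrapper_rmsnorm_init, npu_wrapper_rmsnorm_forward) that receive the original method and return a patched version that adds the bias the Ascend kernels expect. The sketch below reproduces that wrapper pattern on a stand-in class; ToyRMSNorm and patch_method are illustrative stand-ins, not sglang or torch_npu APIs.

# Minimal sketch of the wrapper-patching pattern (illustrative only;
# ToyRMSNorm and patch_method are stand-ins, not sglang APIs).
from typing import Callable

import torch


class ToyRMSNorm(torch.nn.Module):
    def __init__(self, hidden_size: int, eps: float = 1e-6) -> None:
        super().__init__()
        self.weight = torch.nn.Parameter(torch.ones(hidden_size))
        self.variance_epsilon = eps

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        var = x.float().pow(2).mean(-1, keepdim=True)
        return (x.float() * torch.rsqrt(var + self.variance_epsilon)).to(x.dtype) * self.weight


def wrapper_init(func: Callable) -> Callable:
    # Mirrors npu_wrapper_rmsnorm_init: run the original __init__, then attach
    # the extra bias parameter the quantized kernel expects.
    def init(self, hidden_size: int, **extra_args) -> None:
        func(self, hidden_size, **extra_args)
        self.bias = torch.nn.Parameter(torch.zeros(hidden_size), requires_grad=False)

    return init


def wrapper_forward(func: Callable) -> Callable:
    # Mirrors npu_wrapper_rmsnorm_forward: reuse the original forward and add the bias.
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return func(self, x) + self.bias

    return forward


def patch_method(cls: type, name: str, wrapper: Callable) -> None:
    # Stand-in for the role apply_module_patch plays above: replace a method
    # in place so every model that uses the class picks up the patched version.
    setattr(cls, name, wrapper(getattr(cls, name)))


patch_method(ToyRMSNorm, "__init__", wrapper_init)
patch_method(ToyRMSNorm, "forward", wrapper_forward)

norm = ToyRMSNorm(8)
print(norm(torch.randn(2, 8)).shape)  # torch.Size([2, 8])

Wrapping the original method, rather than subclassing, is what lets the config enable the bias only when the quant description asks for it, without touching models that use the stock RMSNorm.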
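
The dynamic path (NPU_W8A8DynamicLinearMethodImpl.apply) quantizes activations per token at run time and folds both the per-token activation scale and the per-channel weight scale back into the matmul output. Below is a rough PyTorch-only model of that arithmetic, assuming symmetric int8 quantization; it is a conceptual sketch, not the Ascend npu_dynamic_quant/npu_quant_matmul kernels, and dynamic_quant/w8a8_linear are illustrative helpers.

# Conceptual model of W8A8 dynamic (per-token activation, per-channel weight)
# int8 matmul; plain PyTorch stand-in, not the Ascend kernels.
import torch


def dynamic_quant(x: torch.Tensor):
    # Per-token symmetric int8 quantization: one scale per row of x.
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127.0
    x_q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return x_q, scale


def w8a8_linear(x: torch.Tensor, w_q: torch.Tensor, w_scale: torch.Tensor) -> torch.Tensor:
    # w_q: [out, in] int8 weights, w_scale: [out] per-channel scales.
    x_q, x_scale = dynamic_quant(x)
    acc = x_q.to(torch.int32) @ w_q.to(torch.int32).t()  # int32 accumulation
    return acc.to(torch.float32) * x_scale * w_scale      # fold both scales back in


# Quick numeric check against the float reference.
torch.manual_seed(0)
w = torch.randn(16, 32)
w_scale = w.abs().amax(dim=1) / 127.0
w_q = torch.clamp(torch.round(w / w_scale[:, None]), -128, 127).to(torch.int8)
x = torch.randn(4, 32)
err = (w8a8_linear(x, w_q, w_scale) - x @ w.t()).abs().max()
print(err)  # quantization error, small relative to the output magnitude

The static path (NPU_W8A8LinearMethodImpl) differs mainly in that the activation scale is a calibrated per-tensor input_scale loaded from the checkpoint rather than computed per token, which is why it also carries input_offset, quant_bias, and deq_scale parameters.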