sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/moe_wna16.py +501 -0
@@ -0,0 +1,501 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/moe_wna16.py
+
+ import logging
+ from typing import Any, Callable, Dict, List, Optional
+
+ import torch
+
+ from sglang.srt.distributed import get_tensor_model_parallel_rank
+ from sglang.srt.distributed.parallel_state import get_tp_group
+ from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+ from sglang.srt.layers.quantization.awq import AWQConfig
+ from sglang.srt.layers.quantization.base_config import (
+     QuantizationConfig,
+     QuantizeMethodBase,
+ )
+ from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+ from sglang.srt.utils import get_device_capability, set_weight_attrs
+
+ logger = logging.getLogger(__name__)
+
+
+ class MoeWNA16Config(QuantizationConfig):
+     """Config class for MOE WNA16 (W8A16/W4A16) quantization."""
+
+     def __init__(
+         self,
+         linear_quant_method: str,
+         weight_bits: int,
+         group_size: int,
+         has_zp: bool,
+         lm_head_quantized: bool,
+         modules_to_not_convert: Optional[List[str]],
+         full_config: Dict[str, Any],
+     ) -> None:
+         super().__init__()
+         self.weight_bits = weight_bits
+         self.group_size = group_size
+         self.has_zp = has_zp
+         self.bit8_pack_factor = 8 // self.weight_bits
+         self.lm_head_quantized = lm_head_quantized
+         self.linear_quant_method = linear_quant_method
+         self.full_config = full_config
+         self.use_marlin = False
+         # Avoid circular import
+
+         if self.linear_quant_method == "gptq":
+             self.use_marlin = GPTQMarlinConfig.is_gptq_marlin_compatible(full_config)
+         elif self.linear_quant_method == "awq":
+             capability_tuple = get_device_capability()
+             device_capability = (
+                 -1
+                 if capability_tuple is None
+                 else capability_tuple[0] * 10 + capability_tuple[1]
+             )
+             awq_min_capability = AWQConfig.get_min_capability()
+             if device_capability < awq_min_capability:
+                 raise ValueError(
+                     "The quantization method moe_wna16 + awq is not supported "
+                     "for the current GPU. "
+                     f"Minimum capability: {awq_min_capability}. "
+                     f"Current capability: {device_capability}."
+                 )
+         else:
+             raise ValueError("moe_wna16 only support gptq and awq.")
+
+         if modules_to_not_convert is None:
+             self.modules_to_not_convert = []
+         else:
+             self.modules_to_not_convert = modules_to_not_convert
+
+     @classmethod
+     def get_name(cls) -> str:
+         return "moe_wna16"
+
+     @classmethod
+     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
+         return [torch.bfloat16, torch.half]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         return 70
+
+     @classmethod
+     def get_config_filenames(cls) -> List[str]:
+         return ["quantize_config.json"]
+
+     def get_scaled_act_names(self) -> List[str]:
+         raise NotImplementedError
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "MoeWNA16Config":
+         quant_method = cls.get_from_keys(config, ["quant_method"])
+         weight_bits = cls.get_from_keys(config, ["bits"])
+         group_size = cls.get_from_keys(config, ["group_size"])
+         lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+         if quant_method == "gptq":
+             has_zp = not cls.get_from_keys(config, ["sym"])
+             modules_to_not_convert = []
+         elif quant_method == "awq":
+             has_zp = cls.get_from_keys(config, ["zero_point"])
+             modules_to_not_convert = cls.get_from_keys_or(
+                 config, ["modules_to_not_convert"], None
+             )
+         else:
+             raise ValueError("moe_wna16 only support gptq and awq.")
+
+         return cls(
+             quant_method,
+             weight_bits,
+             group_size,
+             has_zp,
+             lm_head_quantized,
+             modules_to_not_convert,
+             config,
+         )
+
+     @classmethod
+     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+         can_convert = cls.is_moe_wna16_compatible(hf_quant_cfg)
+         if can_convert and user_quant == "moe_wna16":
+             return cls.get_name()
+         return None
+
+     @classmethod
+     def is_moe_wna16_compatible(cls, quant_config: Dict[str, Any]):
+         # Extract data from quant config.
+         quant_method = quant_config.get("quant_method", "").lower()
+         num_bits = quant_config.get("bits")
+         desc_act = quant_config.get("desc_act")
+
+         capability_tuple = get_device_capability()
+         device_capability = (
+             -1
+             if capability_tuple is None
+             else capability_tuple[0] * 10 + capability_tuple[1]
+         )
+         # Avoid circular import
+         awq_min_capability = AWQConfig.get_min_capability()
+
+         gptq_compatible = quant_method == "gptq" and not desc_act and num_bits in [4, 8]
+         awq_compatible = (
+             quant_method == "awq"
+             and num_bits == 4
+             and device_capability >= awq_min_capability
+         )
+
+         return gptq_compatible or awq_compatible
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["QuantizeMethodBase"]:
+         # avoid circular import
+         from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+         if is_layer_skipped_quant(prefix, self.modules_to_not_convert):
+             return UnquantizedLinearMethod()
+         elif isinstance(layer, LinearBase):
+
+             if self.linear_quant_method == "gptq":
+                 if self.use_marlin:
+                     return GPTQMarlinConfig.from_config(
+                         self.full_config
+                     ).get_quant_method(layer, prefix)
+                 else:
+                     return GPTQConfig.from_config(self.full_config).get_quant_method(
+                         layer, prefix
+                     )
+             elif self.linear_quant_method == "awq":
+                 return AWQConfig.from_config(self.full_config).get_quant_method(
+                     layer, prefix
+                 )
+             else:
+                 raise ValueError("moe_wna16 only support gptq and awq.")
+         elif isinstance(layer, FusedMoE):
+             return MoeWNA16Method(self)
+         return None
+
+
+ def is_layer_skipped_quant(prefix: str, modules_to_not_convert: List[str]):
+     return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+ class MoeWNA16Method:
+     """Linear method for MOE WNA16 (W8A16/W4A16) quantization.
+
+     Args:
+         quant_config: The MOE WNA16 (W8A16/W4A16) quantization config.
+     """
+
+     def __new__(cls, *args, **kwargs):
+         # avoid circular import
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if not hasattr(cls, "_initialized"):
+             original_init = cls.__init__
+             new_cls = type(
+                 cls.__name__,
+                 (FusedMoEMethodBase,),
+                 {
+                     "__init__": original_init,
+                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                 },
+             )
+             obj = super(new_cls, new_cls).__new__(new_cls)
+             obj.__init__(*args, **kwargs)
+             return obj
+         return super().__new__(cls)
+
+     def __init__(self, quant_config: MoeWNA16Config):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size_per_partition: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+         layer.quant_config = self.quant_config
+         bit8_pack_factor = self.quant_config.bit8_pack_factor
+         group_size = self.quant_config.group_size
+         group_size_div_factor = 1
+
+         # make intermediate_size and hidden_size diviable by group_size
+         # we reduce the group size to ensure that
+         # and we would repeat the loaded_weight later
+         while intermediate_size_per_partition % group_size or hidden_size % group_size:
+             group_size = group_size // 2
+             group_size_div_factor *= 2
+             assert group_size >= 32
+         layer.group_size = group_size
+         layer.group_size_div_factor = group_size_div_factor
+
+         strategy = FusedMoeWeightScaleSupported.GROUP.value
+         extra_weight_attrs.update({"quant_method": strategy, "is_transposed": False})
+
+         assert "weight_loader" in extra_weight_attrs
+         weight_loader = extra_weight_attrs["weight_loader"]
+         wrapped_weight_loader = MoeWNA16Method.get_weight_loader(layer, weight_loader)
+         extra_weight_attrs["weight_loader"] = wrapped_weight_loader
+
+         # Fused gate_up_proj (column parallel)
+         w13_qweight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 2 * intermediate_size_per_partition,
+                 hidden_size // bit8_pack_factor,
+                 dtype=torch.uint8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_qweight", w13_qweight)
+         set_weight_attrs(w13_qweight, extra_weight_attrs)
+
+         # down_proj (row parallel)
+         w2_qweight = torch.nn.Parameter(
+             torch.empty(
+                 num_experts,
+                 hidden_size,
+                 intermediate_size_per_partition // bit8_pack_factor,
+                 dtype=torch.uint8,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_qweight", w2_qweight)
+         set_weight_attrs(w2_qweight, extra_weight_attrs)
+
+         w13_scales = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts,
+                 2 * intermediate_size_per_partition,
+                 hidden_size // group_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w13_scales", w13_scales)
+         set_weight_attrs(w13_scales, extra_weight_attrs)
+
+         w2_scales = torch.nn.Parameter(
+             torch.zeros(
+                 num_experts,
+                 hidden_size,
+                 intermediate_size_per_partition // group_size,
+                 dtype=params_dtype,
+             ),
+             requires_grad=False,
+         )
+         layer.register_parameter("w2_scales", w2_scales)
+         set_weight_attrs(w2_scales, extra_weight_attrs)
+
+         if self.quant_config.has_zp:
+             w13_qzeros = torch.nn.Parameter(
+                 torch.zeros(
+                     num_experts,
+                     2 * intermediate_size_per_partition // bit8_pack_factor,
+                     hidden_size // group_size,
+                     dtype=torch.uint8,
+                 ),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w13_qzeros", w13_qzeros)
+             set_weight_attrs(w13_qzeros, extra_weight_attrs)
+
+             w2_qzeros = torch.nn.Parameter(
+                 torch.zeros(
+                     num_experts,
+                     hidden_size // bit8_pack_factor,
+                     intermediate_size_per_partition // group_size,
+                     dtype=torch.uint8,
+                 ),
+                 requires_grad=False,
+             )
+             layer.register_parameter("w2_qzeros", w2_qzeros)
+             set_weight_attrs(w2_qzeros, extra_weight_attrs)
+
+         if self.quant_config.linear_quant_method == "gptq":
+             # some param are unused, but we need to init them in order to
+             # load weights
+             invalid_param_keys = ["w13_g_idx", "w2_g_idx"]
+             if not self.quant_config.has_zp:
+                 invalid_param_keys += ["w13_qzeros", "w2_qzeros"]
+             for key in invalid_param_keys:
+                 param = torch.nn.Parameter(
+                     torch.empty((0,), dtype=torch.int32), requires_grad=False
+                 )
+                 layer.register_parameter(key, param)
+                 set_weight_attrs(param, extra_weight_attrs)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         router_logits: torch.Tensor,
+         top_k: int,
+         renormalize: bool,
+         use_grouped_topk: bool = False,
+         topk_group: Optional[int] = None,
+         num_expert_group: Optional[int] = None,
+         custom_routing_function: Optional[Callable] = None,
+         correction_bias: Optional[torch.Tensor] = None,
+         activation: str = "silu",
+         inplace: bool = True,
+         no_combine: bool = False,
+     ) -> torch.Tensor:
+         # avoid circular import
+         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts
+
+         assert activation == "silu", "Only SiLU activation is supported."
+         topk_weights, topk_ids = select_experts(
+             hidden_states=x,
+             router_logits=router_logits,
+             top_k=top_k,
+             use_grouped_topk=use_grouped_topk,
+             renormalize=renormalize,
+             topk_group=topk_group,
+             num_expert_group=num_expert_group,
+             custom_routing_function=custom_routing_function,
+             correction_bias=correction_bias,
+         )
+
+         weight_bits = self.quant_config.weight_bits
+         has_zp = self.quant_config.has_zp
+
+         return fused_experts(
+             x,
+             layer.w13_qweight,
+             layer.w2_qweight,
+             topk_weights=topk_weights,
+             topk_ids=topk_ids,
+             inplace=inplace,
+             use_int4_w4a16=weight_bits == 4,
+             use_int8_w8a16=weight_bits == 8,
+             w1_scale=layer.w13_scales,
+             w2_scale=layer.w2_scales,
+             w1_zp=layer.w13_qzeros if has_zp else None,
+             w2_zp=layer.w2_qzeros if has_zp else None,
+             block_shape=[0, layer.group_size],
+             no_combine=no_combine,
+         )
+
+     @staticmethod
+     def get_weight_loader(layer, weight_loader):
+
+         def convert_awq_tensor(tensor, tensor_type):
+             # convert awq qweight/qzeros to a standard format (assume int4)
+             # qweight: (k, n // pack_factor_bit32) -> (n, k // pack_factor_bit8)
+             # qzeros: (k // group_size, n // pack_factor_bit32) ->
+             #     (n // pack_factor_bit8, k // group_size)
+             # pack_factor_bit32 = 32 // weight_bits
+             # pack_factor_bit8 = 8 // weight_bits
+
+             # 0. suppose origin shape (a, b), dtype int32
+             # 1. convert to uint8, shape (a, b) -> (a, 4 * b)
+             size0 = tensor.size(0)
+             tensor = tensor.view(torch.uint8)
+
+             # 2. unpack to uint4 (only when weight_bits == 4)
+             #    shape (a, 4 * b) -> (a, 4 * b, 2)
+             shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+             tensor = (tensor[:, :, None] >> shifter) & 0xF
+
+             # 3. change order, see
+             #    https://github.com/casper-hansen/AutoAWQ/blob/v0.2.8/awq/utils/quant_utils.py
+             #    shape -> (a, 4 * b * pack_factor_bit8)
+             reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
+             tensor = tensor.view(-1, 8)[:, reverse_awq_pack_order]
+             tensor = tensor.view(size0, -1)
+
+             # 4. transpose, shape -> (4 * b * pack_factor_bit8, a)
+             tensor = tensor.T.contiguous()
+
+             # 5. repack (only when weight_bits == 4)
+             #    qweight shape -> (4 * b * pack_factor_bit8, a // pack_factor_bit8)
+             #    qzeros shape -> (4 * b, a)
+
+             if tensor_type == "qweight":
+                 tensor = tensor[:, 1::2] * 16 + tensor[:, ::2]
+             elif tensor_type == "qzeros":
+                 tensor = tensor[1::2, :] * 16 + tensor[::2, :]
+             return tensor
+
+         def convert_gptq_int4_qzeros(tensor):
+             tensor = tensor.view(torch.uint8)
+             shifter = torch.tensor([0, 4], dtype=torch.uint8, device=tensor.device)
+             tensor = (tensor[:, :, None] >> shifter) & 0xF
+             tensor = tensor + 1
+             tensor = tensor[:, :, 0] + tensor[:, :, 1] * 16
+             return tensor
+
+         def moe_wna16_weight_loader(
+             param: torch.nn.Parameter,
+             loaded_weight: torch.Tensor,
+             weight_name: str,
+             shard_id: str,
+             expert_id: int,
+         ):
+             if "g_idx" in weight_name:
+                 return
+             if not layer.quant_config.has_zp and "qzeros" in weight_name:
+                 return
+
+             device = get_tp_group().device
+             tp_rank = get_tensor_model_parallel_rank()
+             loaded_weight = loaded_weight.to(device)
+             shard_size = layer.intermediate_size_per_partition
+
+             # convert gptq and awq weight to a standard format
+             if layer.quant_config.linear_quant_method == "awq":
+                 assert layer.quant_config.weight_bits == 4
+                 if "weight" in weight_name:
+                     loaded_weight = convert_awq_tensor(loaded_weight, "qweight")
+                 elif "zeros" in weight_name:
+                     loaded_weight = convert_awq_tensor(loaded_weight, "qzeros")
+                 else:
+                     loaded_weight = loaded_weight.T
+             elif layer.quant_config.linear_quant_method == "gptq":
+                 assert layer.quant_config.weight_bits in [4, 8]
+                 if "weight" in weight_name:
+                     loaded_weight = loaded_weight.T.contiguous().view(torch.uint8)
+                 elif "zeros" in weight_name:
+                     # add 1 to gptq qzeros to align with awq
+                     loaded_weight = loaded_weight.view(torch.uint8)
+                     if layer.quant_config.weight_bits == 4:
+                         loaded_weight = convert_gptq_int4_qzeros(loaded_weight).T
+                     else:
+                         loaded_weight = loaded_weight.T + 1
+                 else:
+                     loaded_weight = loaded_weight.T
+
+             # repeat the qzeros/scales to fit new group size
+             if (
+                 layer.group_size_div_factor > 1
+                 and "qzeros" in weight_name
+                 or "scales" in weight_name
+             ):
+                 loaded_weight = loaded_weight.repeat_interleave(
+                     layer.group_size_div_factor, 1
+                 )
+
+             if "w13_qzeros" in weight_name:
+                 tensor = loaded_weight.view(layer.tp_size, -1, loaded_weight.size(1))[
+                     tp_rank
+                 ]
+                 if shard_id == "w1":
+                     param.data[expert_id, : shard_size // 2] = tensor
+                 else:
+                     param.data[expert_id, shard_size // 2 :] = tensor
+             elif "w2_qzeros" in weight_name:
+                 param.data[expert_id] = loaded_weight.view(
+                     loaded_weight.size(0), layer.tp_size, -1
+                 )[:, tp_rank]
+             else:
+                 weight_loader(param, loaded_weight, weight_name, shard_id, expert_id)
+
+         return moe_wna16_weight_loader
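For reference, a minimal standalone sketch (not taken from the diff) of the nibble-unpacking trick that convert_awq_tensor above relies on: AutoAWQ packs eight int4 values into each int32 in the order [0, 2, 4, 6, 1, 3, 5, 7], so viewing the word as bytes, splitting each byte into low/high nibbles, and indexing with the reverse order recovers the logical values. It assumes a little-endian host, which Tensor.view(torch.uint8) also implies in the code above.

import torch

# Pack logical int4 values v0..v7 into one int32 word in AutoAWQ order.
values = torch.arange(8, dtype=torch.int32)
awq_order = [0, 2, 4, 6, 1, 3, 5, 7]
packed = torch.zeros(1, dtype=torch.int32)
for slot, idx in enumerate(awq_order):
    packed |= values[idx] << (4 * slot)

# Unpack: reinterpret as bytes, split nibbles, undo the AWQ interleaving.
as_bytes = packed.view(torch.uint8)                 # shape (4,)
shifter = torch.tensor([0, 4], dtype=torch.uint8)
nibbles = (as_bytes[:, None] >> shifter) & 0xF      # shape (4, 2)
reverse_awq_pack_order = [0, 4, 1, 5, 2, 6, 3, 7]
restored = nibbles.reshape(-1)[reverse_awq_pack_order]
assert restored.tolist() == list(range(8))          # v0..v7 recovered in order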
sglang/srt/layers/quantization/utils.py +1 -1
@@ -1,7 +1,7 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/quant_utils.py

  from types import MappingProxyType
- from typing import List, Mapping, Tuple, Union
+ from typing import List, Mapping, Optional, Tuple, Union

  import torch

sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
@@ -37,7 +37,7 @@ class W8A8Fp8Config(QuantizationConfig):
      Note:
      - For models without offline quantization, weights will be quantized during model loading
      - If CUTLASS is supported: Per-channel weight quantization is used
-     - If CUTLASS is not supported: Falls back to per-token weight quantization
+     - If CUTLASS is not supported: Falls back to per-tensor weight quantization
      """

      def __init__(self, is_checkpoint_fp8_serialized: bool = False):
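As context for the one-word docstring fix above, a small illustrative sketch of the difference between the two weight-quantization granularities; the shapes and the FP8 E4M3 maximum of 448.0 are assumptions for the example, not values taken from this package.

import torch

w = torch.randn(128, 64)   # (out_features, in_features), made-up shapes
fp8_max = 448.0            # largest finite value of float8_e4m3fn

# Per-channel: one scale per output row, shape (out_features, 1).
per_channel_scale = w.abs().amax(dim=1, keepdim=True) / fp8_max
# Per-tensor fallback: a single scalar scale shared by the whole weight.
per_tensor_scale = w.abs().amax() / fp8_max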
sglang/srt/lora/backend/base_backend.py +4 -4
@@ -5,7 +5,7 @@ import torch
  from sglang.srt.lora.utils import LoRABatchInfo


- def get_fuse_output_scaling_add_from_name(name: str) -> bool:
+ def get_fuse_output_add_from_name(name: str) -> bool:
      mapping = {
          "triton": True,
          "flashinfer": False,
@@ -28,14 +28,14 @@ class BaseLoRABackend:
      Args:
          name: name of backend
          batch_info: information of current batch for use
-         fuse_output_scaling_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
-             and the operation of scaling and adding will be fused into kernel
+         fuse_output_add: if set to True, the output buffer for storing result will be passed in when doing lora_b forward,
+             and the operation of adding will be fused into kernel
      """

      def __init__(self, name: str, batch_info: LoRABatchInfo = None):
          self.name = name
          self.batch_info = batch_info
-         self.fuse_output_scaling_add = get_fuse_output_scaling_add_from_name(name)
+         self.fuse_output_add = get_fuse_output_add_from_name(name)
          self.fuse_stacked_lora_b = get_fuse_stacked_lora_b_from_name(name)

      def run_lora_a_sgemm(
sglang/srt/lora/backend/flashinfer_backend.py +12 -9
@@ -37,13 +37,16 @@ class FlashInferLoRABackend(BaseLoRABackend):
          self, x: torch.Tensor, weights: torch.Tensor, *args, **kwargs
      ) -> torch.Tensor:

-         return self.segment_gemm.run(
-             x=x,
-             weights=weights,
-             batch_size=self.batch_info.bs,
-             weight_column_major=True,
-             seg_indptr=self.batch_info.seg_indptr,
-             weight_indices=self.batch_info.weight_indices,
+         return (
+             self.segment_gemm.run(
+                 x=x,
+                 weights=weights,
+                 batch_size=self.batch_info.bs,
+                 weight_column_major=True,
+                 seg_indptr=self.batch_info.seg_indptr,
+                 weight_indices=self.batch_info.weight_indices,
+             )
+             * self.batch_info.scalings[0]
          )

      def run_qkv_lora(
@@ -90,7 +93,7 @@ class FlashInferLoRABackend(BaseLoRABackend):
              weights=kv_lora_b[1],
          )

-         return lora_output
+         return lora_output * self.batch_info.scalings[0]

      def run_gate_up_lora(
          self,
@@ -125,4 +128,4 @@ class FlashInferLoRABackend(BaseLoRABackend):
              weights=gate_up_lora_b[1],
          )

-         return lora_output
+         return lora_output * self.batch_info.scalings[0]
sglang/srt/lora/backend/triton_backend.py +5 -8
@@ -25,11 +25,10 @@ class TritonLoRABackend(BaseLoRABackend):
          x: torch.Tensor,
          weights: torch.Tensor,
          base_output: torch.Tensor = None,
-         scaling: float = 1.0,
          *args,
          **kwargs
      ) -> torch.Tensor:
-         return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output, scaling)
+         return sgemm_lora_b_fwd(x, weights, self.batch_info, base_output)

      def run_qkv_lora(
          self,
@@ -39,7 +38,6 @@ class TritonLoRABackend(BaseLoRABackend):
          output_offset: torch.Tensor,
          max_qkv_out_dim: int,
          base_output: torch.Tensor = None,
-         scaling: float = 1.0,
          *args,
          **kwargs
      ) -> torch.Tensor:
@@ -49,7 +47,7 @@ class TritonLoRABackend(BaseLoRABackend):
          # qkv_lora_b: (num_lora, output_dim_q + 2 * output_dim_kv, r)
          assert isinstance(qkv_lora_b, torch.Tensor)

-         lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info)
+         lora_a_output = sgemm_lora_a_fwd(x, qkv_lora_a, self.batch_info, stack_num=3)
          lora_output = qkv_lora_b_fwd(
              lora_a_output,
              qkv_lora_b,
@@ -57,7 +55,6 @@ class TritonLoRABackend(BaseLoRABackend):
              output_offset,
              max_qkv_out_dim,
              base_output,
-             scaling,
          )
          return lora_output

@@ -67,7 +64,6 @@ class TritonLoRABackend(BaseLoRABackend):
          gate_up_lora_a: torch.Tensor,
          gate_up_lora_b: torch.Tensor,
          base_output: torch.Tensor = None,
-         scaling: float = 1.0,
          *args,
          **kwargs
      ) -> torch.Tensor:
@@ -79,13 +75,14 @@ class TritonLoRABackend(BaseLoRABackend):
          output_dim = gate_up_lora_b.shape[-2] // 2

          # lora_a_output: (s, 2 * r)
-         lora_a_output = sgemm_lora_a_fwd(x, gate_up_lora_a, self.batch_info)
+         lora_a_output = sgemm_lora_a_fwd(
+             x, gate_up_lora_a, self.batch_info, stack_num=2
+         )
          lora_output = gate_up_lora_b_fwd(
              lora_a_output,
              gate_up_lora_b,
              self.batch_info,
              output_dim,
              base_output,
-             scaling,
          )
          return lora_output
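The scaling factor that the FlashInfer backend above now applies outside segment_gemm, and that the Triton backend no longer receives as an explicit argument, is the usual LoRA factor (typically alpha / r; here assumed to be what batch_info.scalings carries per adapter). A dense reference computation with made-up shapes, useful as a correctness baseline for the fused kernels:

import torch

s, in_dim, out_dim, r, alpha = 4, 64, 32, 8, 16     # illustrative sizes
x = torch.randn(s, in_dim)
lora_a = torch.randn(r, in_dim)                     # A: project down to rank r
lora_b = torch.randn(out_dim, r)                    # B: project back up
base_output = torch.randn(s, out_dim)               # output of the frozen base layer

scaling = alpha / r                                 # assumed per-adapter scaling
output = base_output + (x @ lora_a.T @ lora_b.T) * scaling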