sglang 0.5.4.post1__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (150)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +18 -3
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  5. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +120 -0
  6. sglang/srt/checkpoint_engine/__init__.py +9 -0
  7. sglang/srt/checkpoint_engine/update.py +317 -0
  8. sglang/srt/configs/__init__.py +2 -0
  9. sglang/srt/configs/deepseek_ocr.py +542 -10
  10. sglang/srt/configs/deepseekvl2.py +95 -194
  11. sglang/srt/configs/kimi_linear.py +160 -0
  12. sglang/srt/configs/mamba_utils.py +66 -0
  13. sglang/srt/configs/model_config.py +25 -2
  14. sglang/srt/constants.py +7 -0
  15. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  16. sglang/srt/disaggregation/decode.py +34 -6
  17. sglang/srt/disaggregation/nixl/conn.py +2 -2
  18. sglang/srt/disaggregation/prefill.py +25 -3
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  20. sglang/srt/distributed/parallel_state.py +9 -5
  21. sglang/srt/entrypoints/engine.py +13 -5
  22. sglang/srt/entrypoints/http_server.py +22 -3
  23. sglang/srt/entrypoints/openai/protocol.py +7 -1
  24. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  27. sglang/srt/environ.py +7 -0
  28. sglang/srt/eplb/expert_distribution.py +34 -1
  29. sglang/srt/eplb/expert_location.py +106 -36
  30. sglang/srt/grpc/compile_proto.py +3 -0
  31. sglang/srt/layers/attention/ascend_backend.py +233 -5
  32. sglang/srt/layers/attention/attention_registry.py +3 -0
  33. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  34. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  35. sglang/srt/layers/attention/fla/kda.py +1359 -0
  36. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  37. sglang/srt/layers/attention/flashattention_backend.py +7 -6
  38. sglang/srt/layers/attention/flashinfer_mla_backend.py +3 -1
  39. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  40. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  41. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  42. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  43. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  44. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  45. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  46. sglang/srt/layers/attention/nsa_backend.py +157 -23
  47. sglang/srt/layers/attention/triton_backend.py +4 -1
  48. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  49. sglang/srt/layers/attention/trtllm_mla_backend.py +10 -2
  50. sglang/srt/layers/communicator.py +23 -1
  51. sglang/srt/layers/layernorm.py +16 -2
  52. sglang/srt/layers/logits_processor.py +4 -20
  53. sglang/srt/layers/moe/ep_moe/layer.py +0 -18
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  57. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  58. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  59. sglang/srt/layers/moe/moe_runner/deep_gemm.py +53 -33
  60. sglang/srt/layers/moe/token_dispatcher/deepep.py +12 -9
  61. sglang/srt/layers/moe/topk.py +31 -6
  62. sglang/srt/layers/pooler.py +21 -2
  63. sglang/srt/layers/quantization/__init__.py +9 -78
  64. sglang/srt/layers/quantization/auto_round.py +394 -0
  65. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  66. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  67. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  68. sglang/srt/layers/rotary_embedding.py +117 -45
  69. sglang/srt/lora/lora_registry.py +9 -0
  70. sglang/srt/managers/async_mm_data_processor.py +122 -0
  71. sglang/srt/managers/data_parallel_controller.py +30 -3
  72. sglang/srt/managers/detokenizer_manager.py +3 -0
  73. sglang/srt/managers/io_struct.py +26 -4
  74. sglang/srt/managers/multi_tokenizer_mixin.py +5 -0
  75. sglang/srt/managers/schedule_batch.py +74 -15
  76. sglang/srt/managers/scheduler.py +164 -129
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  78. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  79. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  80. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  81. sglang/srt/managers/session_controller.py +6 -5
  82. sglang/srt/managers/tokenizer_manager.py +154 -59
  83. sglang/srt/managers/tp_worker.py +24 -1
  84. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  85. sglang/srt/mem_cache/common.py +1 -0
  86. sglang/srt/mem_cache/memory_pool.py +171 -57
  87. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  88. sglang/srt/mem_cache/radix_cache.py +4 -0
  89. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  90. sglang/srt/metrics/collector.py +46 -3
  91. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  92. sglang/srt/model_executor/forward_batch_info.py +11 -11
  93. sglang/srt/model_executor/model_runner.py +76 -21
  94. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  95. sglang/srt/model_loader/weight_utils.py +1 -1
  96. sglang/srt/models/bailing_moe.py +9 -2
  97. sglang/srt/models/deepseek_nextn.py +11 -2
  98. sglang/srt/models/deepseek_v2.py +149 -34
  99. sglang/srt/models/glm4.py +391 -77
  100. sglang/srt/models/glm4v.py +196 -55
  101. sglang/srt/models/glm4v_moe.py +0 -1
  102. sglang/srt/models/gpt_oss.py +1 -10
  103. sglang/srt/models/kimi_linear.py +678 -0
  104. sglang/srt/models/llama4.py +1 -1
  105. sglang/srt/models/llama_eagle3.py +11 -1
  106. sglang/srt/models/longcat_flash.py +2 -2
  107. sglang/srt/models/minimax_m2.py +1 -1
  108. sglang/srt/models/qwen2.py +1 -1
  109. sglang/srt/models/qwen2_moe.py +30 -15
  110. sglang/srt/models/qwen3.py +1 -1
  111. sglang/srt/models/qwen3_moe.py +16 -8
  112. sglang/srt/models/qwen3_next.py +7 -0
  113. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  114. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  115. sglang/srt/multiplex/pdmux_context.py +164 -0
  116. sglang/srt/parser/conversation.py +7 -1
  117. sglang/srt/sampling/custom_logit_processor.py +67 -1
  118. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  119. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  120. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  121. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  122. sglang/srt/server_args.py +103 -22
  123. sglang/srt/single_batch_overlap.py +4 -1
  124. sglang/srt/speculative/draft_utils.py +16 -0
  125. sglang/srt/speculative/eagle_info.py +42 -36
  126. sglang/srt/speculative/eagle_info_v2.py +68 -25
  127. sglang/srt/speculative/eagle_utils.py +261 -16
  128. sglang/srt/speculative/eagle_worker.py +11 -3
  129. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  130. sglang/srt/speculative/spec_info.py +305 -31
  131. sglang/srt/speculative/spec_utils.py +44 -8
  132. sglang/srt/tracing/trace.py +121 -12
  133. sglang/srt/utils/common.py +55 -32
  134. sglang/srt/utils/hf_transformers_utils.py +38 -16
  135. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  136. sglang/test/kits/radix_cache_server_kit.py +50 -0
  137. sglang/test/runners.py +31 -7
  138. sglang/test/simple_eval_common.py +5 -3
  139. sglang/test/simple_eval_humaneval.py +1 -0
  140. sglang/test/simple_eval_math.py +1 -0
  141. sglang/test/simple_eval_mmlu.py +1 -0
  142. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  143. sglang/test/test_utils.py +7 -1
  144. sglang/version.py +1 -1
  145. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +10 -24
  146. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +150 -136
  147. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  148. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  149. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  150. {sglang-0.5.4.post1.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py
@@ -7,33 +7,16 @@ from typing import TYPE_CHECKING, Dict, Optional, Type
 
 import torch
 
-try:
-    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
-        GPTQMarlin24Config,
-    )
-    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-    from vllm.model_executor.layers.quantization.qqq import QQQConfig
-    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
-
-    VLLM_AVAILABLE = True
-except ImportError as e:
-    VLLM_AVAILABLE = False
-    VLLM_IMPORT_ERROR = e
 
-    # Define empty classes as placeholders when vllm is not available
-    class DummyConfig:
-        def override_quantization_method(self, *args, **kwargs):
-            return None
+# Define empty classes as placeholders when vllm is not available
+class DummyConfig:
+    def override_quantization_method(self, *args, **kwargs):
+        return None
 
-    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
-        ExpertsInt8Config
-    ) = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
+CompressedTensorsConfig = DummyConfig
 
+from sglang.srt.layers.quantization.auto_round import AutoRoundConfig
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
@@ -62,7 +45,7 @@ _is_mxfp_supported = mxfp_supported()
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
 
-# Base quantization methods that don't depend on vllm
+# Base quantization methods
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
@@ -82,6 +65,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
     "fbgemm_fp8": FBGEMMFp8Config,
+    "auto-round": AutoRoundConfig,
 }
 
 
@@ -101,19 +85,8 @@ elif _is_mxfp_supported and is_hip():
             "mxfp4": Mxfp4Config,
         }
     )
-# VLLM-dependent quantization methods
-VLLM_QUANTIZATION_METHODS = {
-    "aqlm": AQLMConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
-    "marlin": MarlinConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
-}
 
-QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS}
 
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -122,50 +95,8 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
-    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
-        raise ValueError(
-            f"{quantization} quantization requires some operators from vllm. "
-            f"Please install vllm by `pip install vllm==0.9.0.1`\n"
-            f"Import error: {VLLM_IMPORT_ERROR}"
-        )
 
     return QUANTIZATION_METHODS[quantization]
 
 
 original_isinstance = builtins.isinstance
-
-
-def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
-    """
-    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
-    can recognize sglang layers
-    """
-    if not VLLM_AVAILABLE:
-        return
-
-    if reverse:
-        builtins.isinstance = original_isinstance
-        return
-
-    from vllm.model_executor.layers.fused_moe import FusedMoE
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding,
-    )
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
-
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
-
-    builtins.isinstance = patched_isinstance
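For context, here is a minimal usage sketch (not part of the diff) of what the trimmed registry implies: the new "auto-round" key now resolves through `get_quantization_config` without any vLLM dependency, while the former vLLM-only entries such as "marlin" are gone. The snippet assumes a working sglang 0.5.4.post2 install in which this module imports cleanly.

```python
# Hypothetical usage sketch against the registry shown above.
from sglang.srt.layers.quantization import (
    QUANTIZATION_METHODS,
    get_quantization_config,
)

cfg_cls = get_quantization_config("auto-round")
print(cfg_cls.get_name())                 # "auto-round"
print("marlin" in QUANTIZATION_METHODS)   # False: vllm-only methods were removed
```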
sglang/srt/layers/quantization/auto_round.py
@@ -0,0 +1,394 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import logging
+import re
+from fractions import Fraction
+from typing import Any, Optional, Union
+
+import torch
+
+logger = logging.getLogger(__name__)
+
+from sglang.srt.layers.quantization.utils import get_scalar_types
+
+ScalarType, scalar_types = get_scalar_types()
+
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+
+class AutoRoundConfig(QuantizationConfig):
+    """Config class for AutoRound.
+    Reference: https://arxiv.org/pdf/2309.05516
+    """
+
+    SUPPORTED_BITS = {2, 3, 4, 8}
+    SUPPORTED_DTYPES = {"int"}
+    SUPPORTED_FORMATS = {"auto_round:auto_gptq", "auto_round:auto_awq"}
+    SUPPORTED_BACKENDS = {"auto", "gptq", "gptq:marlin", "awq", "awq:marlin", "marlin"}
+
+    def __init__(
+        self,
+        weight_bits: int,
+        group_size: int,
+        sym: bool = True,
+        packing_format: str = "auto_round:auto_gptq",
+        block_name_to_quantize: Optional[Union[str, list[str]]] = None,
+        extra_config: Optional[dict[str, Any]] = None,
+        data_type: str = "int",
+        backend: str = "auto",
+    ) -> None:
+        super().__init__()
+        if weight_bits not in self.SUPPORTED_BITS:
+            raise ValueError(
+                f"Unsupported weight_bits: {weight_bits}, "
+                f"currently only support {self.SUPPORTED_BITS}"
+            )
+        if data_type not in self.SUPPORTED_DTYPES:
+            raise ValueError(
+                f"Unsupported data_type: {data_type},"
+                f" currently only support {self.SUPPORTED_DTYPES}"
+            )
+        if packing_format not in self.SUPPORTED_FORMATS:
+            raise ValueError(
+                f"Unsupported packing_format: {packing_format}, "
+                f"currently only support {self.SUPPORTED_FORMATS}"
+            )
+        if backend not in self.SUPPORTED_BACKENDS:
+            raise ValueError(
+                f"Unsupported backend: {backend}, "
+                f"currently only support {self.SUPPORTED_BACKENDS}"
+            )
+
+        self.weight_bits = weight_bits
+        self.group_size = group_size
+        self.sym = sym
+        self.packing_format = packing_format
+        self.block_name_to_quantize = (
+            block_name_to_quantize.split(",")
+            if isinstance(block_name_to_quantize, str)
+            else block_name_to_quantize
+        )
+        self.extra_config = extra_config
+        self.data_type = data_type
+        self.backend = backend
+        self.pack_factor = Fraction(32, weight_bits)
+
+    def __repr__(self) -> str:
+        return (
+            f"AutoRoundConfig(weight_bits={self.weight_bits}, "
+            f"group_size={self.group_size}, sym={self.sym})"
+        )
+
+    @classmethod
+    def get_name(cls):
+        return "auto-round"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 60
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantization_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "AutoRoundConfig":
+        return cls(
+            weight_bits=cls.get_from_keys(config, ["bits"]),
+            group_size=cls.get_from_keys(config, ["group_size"]),
+            sym=cls.get_from_keys(config, ["sym"]),
+            packing_format=cls.get_from_keys_or(
+                config,
+                ["packing_format"],
+                "auto_round:auto_gptq",
+            ),
+            block_name_to_quantize=cls.get_from_keys_or(
+                config, ["block_name_to_quantize", "to_quant_block_names"], None
+            ),
+            extra_config=cls.get_from_keys_or(config, ["extra_config"], None),
+            data_type=cls.get_from_keys_or(config, ["data_type"], "int"),
+            backend=cls.get_from_keys_or(
+                config, ["backend", "vllm_backend", "sglang_backend"], "auto"
+            ),
+        )
+
+    def get_scaled_act_names(self) -> list[str]:
+        """Returns the activation function names that should be post-scaled.
+
+        For now, this is only used by AWQ.
+        """
+        raise NotImplementedError
+
+    def get_layer_config(self, layer, layer_name: str):
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        def get_config(name: str, quantized: bool = True):
+            if not self.extra_config:
+                return (
+                    self.weight_bits if quantized else 16,
+                    self.group_size if quantized else -1,
+                    self.sym if quantized else True,
+                )
+
+            # Exact match first
+            if name in self.extra_config:
+                cfg = self.extra_config[name]
+                return (
+                    cfg.get("bits", self.weight_bits if quantized else 16),
+                    cfg.get("group_size", self.group_size if quantized else -1),
+                    cfg.get("sym", self.sym if quantized else True),
+                )
+
+            REGEX_SPECIAL_CHARS = set(r"*+?^$()[]{}|\\")
+            for pattern, cfg in self.extra_config.items():
+                if not isinstance(pattern, str) or not any(
+                    c in REGEX_SPECIAL_CHARS for c in pattern
+                ):
+                    continue
+
+                try:
+                    if re.fullmatch(pattern, name):
+                        return (
+                            cfg.get("bits", self.weight_bits if quantized else 16),
+                            cfg.get("group_size", self.group_size if quantized else -1),
+                            cfg.get("sym", self.sym if quantized else True),
+                        )
+                except re.error:
+                    # Invalid regex, ignore.
+                    continue
+
+            return (
+                self.weight_bits if quantized else 16,
+                self.group_size if quantized else -1,
+                self.sym if quantized else True,
+            )
+
+        # 1. Exact match from config
+        if self.extra_config and layer_name in self.extra_config:
+            return get_config(layer_name)
+
+        # 2. Determine whether layer should be quantized
+        quantized = not isinstance(layer, ParallelLMHead)
+        if self.block_name_to_quantize:
+            quantized = any(
+                layer_name.startswith(name) for name in self.block_name_to_quantize
+            )
+
+        # 3. Handle fused MoE
+        if self.extra_config and "fusedmoe" in layer.__class__.__name__.lower():
+            moe_configs = [
+                get_config(name, quantized)
+                for name in self.extra_config
+                if name.startswith(layer_name)
+            ]
+            if moe_configs:
+                if len(set(moe_configs)) == 1:
+                    return moe_configs[0]
+                raise ValueError(
+                    f"Fused MoE layer '{layer_name}' requires "
+                    f"consistent quant config for all sub-layers"
+                )
+
+        # 4. Handle fused QKV or other patterns
+        if self.extra_config:
+            for fusion_key, sub_keys in self.packed_modules_mapping.items():
+                if fusion_key in layer_name and layer_name.count(fusion_key) == 1:
+                    sub_names = [
+                        layer_name.replace(fusion_key, sub_key) for sub_key in sub_keys
+                    ]
+                    sub_configs = [get_config(name, quantized) for name in sub_names]
+                    if len(set(sub_configs)) == 1:
+                        return sub_configs[0]
+                    raise ValueError(
+                        f"Fused module '{layer_name}' requires "
+                        f"consistent quant config for {sub_names}"
+                    )
+
+        # 5. Fallback or try a regular expression match
+        return get_config(layer_name, quantized)
+
+    def check_quantized(self, weight_bits: int) -> bool:
+        return weight_bits < 16
+
+    def apply_awq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            AWQ_TYPE_MAP = {
+                4: scalar_types.uint4,
+                8: scalar_types.uint8,
+            }
+            use_marlin = (weight_bits in AWQ_TYPE_MAP) and check_marlin_supported(
+                AWQ_TYPE_MAP[weight_bits], group_size, not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.awq import (
+                AWQMarlinConfig,
+                AWQMarlinLinearMethod,
+                AWQMoEMethod,
+            )
+
+            quant_args_marlin = AWQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+                lm_head_quantized=False,
+                full_config={},
+                modules_to_not_convert=[],
+            )
+        else:
+            from sglang.srt.layers.quantization.awq import AWQConfig, AWQLinearMethod
+
+            quant_args = AWQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                zero_point=not sym,
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                return AWQMoEMethod(quant_args_marlin)
+            from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+            config = {
+                "quant_method": "awq",
+                "bits": weight_bits,
+                "group_size": group_size,
+                "zero_point": not sym,
+                "lm_head": False,
+            }
+            return MoeWNA16Config.from_config(config).get_quant_method(layer, prefix)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return AWQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return AWQLinearMethod(quant_args)
+        return None
+
+    def apply_gptq_quant_layer(self, layer, prefix: str, backend: str = "auto"):
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+        from sglang.srt.layers.quantization.marlin_utils import (
+            check_marlin_supported,
+            check_moe_marlin_supports_layer,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        weight_bits, group_size, sym = self.get_layer_config(layer, prefix)
+        if not self.check_quantized(weight_bits):
+            if isinstance(layer, (LinearBase, ParallelLMHead)):
+                return UnquantizedLinearMethod()
+            else:
+                return None
+
+        logger.debug(
+            "[%s] Type: %s, Bits: %s, Group Size: %s, Sym: %s",
+            prefix,
+            layer.__class__.__name__,
+            weight_bits,
+            group_size,
+            sym,
+        )
+        if backend == "auto" or "marlin" in backend:
+            GPTQ_TYPE_MAP = {
+                (4, True): scalar_types.uint4b8,
+                (8, True): scalar_types.uint8b128,
+            }
+            use_marlin = (weight_bits, sym) in GPTQ_TYPE_MAP and check_marlin_supported(
+                GPTQ_TYPE_MAP[(weight_bits, sym)], group_size, has_zp=not sym
+            )
+            if isinstance(layer, FusedMoE):
+                use_marlin = use_marlin and check_moe_marlin_supports_layer(
+                    layer, group_size
+                )
+        else:
+            use_marlin = False
+        if use_marlin:
+            from sglang.srt.layers.quantization.gptq import (
+                GPTQMarlinConfig,
+                GPTQMarlinLinearMethod,
+                GPTQMarlinMoEMethod,
+            )
+
+            quant_args_marlin = GPTQMarlinConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                is_sym=sym,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+                full_config={},
+            )
+        else:
+            from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQLinearMethod
+
+            quant_args = GPTQConfig(
+                weight_bits=weight_bits,
+                group_size=group_size,
+                lm_head_quantized=False,
+                desc_act=False,
+                dynamic={},
+            )
+
+        if isinstance(layer, FusedMoE):
+            if use_marlin:
+                from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+
+                config = {
+                    "quant_method": "gptq",
+                    "bits": weight_bits,
+                    "group_size": group_size,
+                    "sym": sym,
+                    "lm_head": False,
+                }
+                return MoeWNA16Config.from_config(config).get_quant_method(
+                    layer, prefix
+                )
+            return GPTQMarlinMoEMethod(quant_args_marlin)
+
+        if isinstance(layer, (LinearBase, ParallelLMHead)):
+            if use_marlin:
+                return GPTQMarlinLinearMethod(quant_args_marlin)
+            else:
+                return GPTQLinearMethod(quant_args)
+
+        return None
+
+    def get_quant_method(self, layer: torch.nn.Module, prefix: str):
+        # TODO enable CPU quant method later
+        if "gptq" in self.packing_format or "gptq" in self.backend:
+            return self.apply_gptq_quant_layer(layer, prefix)
+        if "awq" in self.packing_format or "awq" in self.backend:
+            return self.apply_awq_quant_layer(layer, prefix)
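As a rough illustration of the new config class (my sketch, not part of the package), `AutoRoundConfig.from_config` reads a HuggingFace-style `quantization_config` dict using the keys shown above; the sample values below are hypothetical and assume the module imports cleanly in your environment.

```python
# Hypothetical sketch: parsing a 4-bit, group-size-128, symmetric AutoRound
# checkpoint config packed in the auto_gptq format.
from sglang.srt.layers.quantization.auto_round import AutoRoundConfig

quant_cfg = {
    "bits": 4,
    "group_size": 128,
    "sym": True,
    "packing_format": "auto_round:auto_gptq",
}
config = AutoRoundConfig.from_config(quant_cfg)
print(config)              # AutoRoundConfig(weight_bits=4, group_size=128, sym=True)
print(config.pack_factor)  # Fraction(32, 4) reduces to 8 int4 weights per int32 word
```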
sglang/srt/layers/quantization/fp8_kernel.py
@@ -459,7 +459,7 @@ def create_per_token_group_quant_fp8_output_scale(
             x_shape[:-2] + (x_shape[-1] // group_size, aligned_size),
             device=device,
             dtype=torch.float32,
-        ).permute(-1, -2)[: x_shape[-2], :]
+        ).transpose(-1, -2)[: x_shape[-2], :]
     else:
         return torch.empty(
             (x_shape[-1] // group_size,) + x_shape[:-1],
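The one-line change above swaps `permute(-1, -2)` for `transpose(-1, -2)`. On tensors with more than two dimensions, `Tensor.permute` requires a full permutation of all dims, so passing only two raises a `RuntimeError`, whereas `transpose` simply swaps the named pair, which is all the scale buffer needs here. A small standalone demonstration (my own, not from the diff):

```python
import torch

x = torch.empty(2, 6, 4)      # any tensor with more than two dimensions

y = x.transpose(-1, -2)       # OK: swaps only the last two dims
print(y.shape)                # torch.Size([2, 4, 6])

try:
    x.permute(-1, -2)         # fails: permute needs all 3 dims specified
except RuntimeError as err:
    print("permute error:", err)
```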
sglang/srt/layers/quantization/fp8_utils.py
@@ -5,7 +5,7 @@ import torch
 from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.quantization.mxfp4_tensor import MXFP4QuantizeUtil
-from sglang.srt.utils import ceil_div, is_sm100_supported, offloader
+from sglang.srt.utils import ceil_div, is_blackwell_supported, offloader
 
 try:
     from vllm import _custom_ops as ops
@@ -129,7 +129,7 @@ def cutlass_block_fp8_supported() -> bool:
 CUTLASS_BLOCK_FP8_SUPPORTED = cutlass_block_fp8_supported()
 ENABLE_FLASHINFER_GEMM = (
     get_bool_env_var("SGLANG_ENABLE_FLASHINFER_GEMM")
-    and is_sm100_supported()
+    and is_blackwell_supported()
     and is_flashinfer_available()
 )
 if ENABLE_FLASHINFER_GEMM:
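The renamed helper suggests the FlashInfer GEMM gate now keys off the Blackwell generation rather than SM100 specifically. A rough, hypothetical stand-in for such a capability check (not the actual `is_blackwell_supported` implementation, which may differ) could look like:

```python
import torch

def blackwell_like_check() -> bool:
    # Hypothetical: treat compute capability 10.x devices as Blackwell-class.
    # The real sglang helper may cover additional SKUs or use different logic.
    if not torch.cuda.is_available():
        return False
    major, _minor = torch.cuda.get_device_capability()
    return major == 10
```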