sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff shows the changes between two publicly released versions of this package, as published to their public registry, and is provided for informational purposes only.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,200 @@
+ # SPDX-License-Identifier: Apache-2.0
+ import logging
+ from typing import Any, Dict, List, Optional
+
+ import torch
+ from sgl_kernel import awq_dequantize
+
+ from sglang.srt.layers.linear import (
+     LinearBase,
+     LinearMethodBase,
+     UnquantizedLinearMethod,
+ )
+ from sglang.srt.layers.parameter import GroupQuantScaleParameter, PackedvLLMParameter
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+
+ logger = logging.getLogger(__name__)
+
+
+ def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
+     return any(module_name in prefix for module_name in modules_to_not_convert)
+
+
+ class AWQConfig(QuantizationConfig):
+     """Config class for AWQ.
+
+     Reference: https://arxiv.org/abs/2306.00978
+     """
+
+     def __init__(
+         self,
+         weight_bits: int,
+         group_size: int,
+         zero_point: bool,
+         modules_to_not_convert: Optional[List[str]] = None,
+     ) -> None:
+         super().__init__()
+         self.weight_bits = weight_bits
+         self.group_size = group_size
+         self.zero_point = zero_point
+         self.modules_to_not_convert = modules_to_not_convert or []
+
+         if self.weight_bits != 4:
+             raise ValueError(
+                 "Currently, only 4-bit weight quantization is supported for "
+                 f"AWQ, but got {self.weight_bits} bits."
+             )
+         self.pack_factor = 32 // self.weight_bits
+
+     def __repr__(self) -> str:
+         return (
+             f"AWQConfig(weight_bits={self.weight_bits}, "
+             f"group_size={self.group_size}, "
+             f"zero_point={self.zero_point}, "
+             f"modules_to_not_convert={self.modules_to_not_convert})"
+         )
+
+     def get_scaled_act_names(self) -> List[str]:
+         return []
+
+     def get_name(self) -> str:
+         return "awq"
+
+     def get_supported_act_dtypes(self) -> List[torch.dtype]:
+         return [torch.half]
+
+     @classmethod
+     def get_min_capability(cls) -> int:
+         # The AWQ kernel only supports Turing or newer GPUs.
+         return 75
+
+     @staticmethod
+     def get_config_filenames() -> List[str]:
+         return [
+             "quant_config.json",  # E.g., casperhansen/vicuna-7b-v1.5-awq
+             # E.g., abhinavkulkarni/mosaicml-mpt-7b-instruct-w4-g128-awq
+             "quantize_config.json",
+         ]
+
+     @classmethod
+     def from_config(cls, config: Dict[str, Any]) -> "AWQConfig":
+         weight_bits = cls.get_from_keys(config, ["w_bit", "bits"])
+         group_size = cls.get_from_keys(config, ["q_group_size", "group_size"])
+         zero_point = cls.get_from_keys(config, ["zero_point"])
+         modules_to_not_convert = cls.get_from_keys_or(
+             config, ["modules_to_not_convert"], None
+         )
+         return cls(weight_bits, group_size, zero_point, modules_to_not_convert)
+
+     def get_quant_method(
+         self, layer: torch.nn.Module, prefix: str
+     ) -> Optional["LinearMethodBase"]:
+
+         if isinstance(layer, LinearBase):
+             if is_layer_skipped_awq(prefix, self.modules_to_not_convert):
+                 return UnquantizedLinearMethod()
+             return AWQLinearMethod(self)
+         return None
+
+
+ class AWQLinearMethod(LinearMethodBase):
+     """Linear method for AWQ.
+
+     Args:
+         quant_config: The AWQ quantization config.
+     """
+
+     def __init__(self, quant_config: AWQConfig):
+         self.quant_config = quant_config
+
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         if input_size_per_partition % self.quant_config.group_size != 0:
+             raise ValueError(
+                 "The input size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size."
+             )
+
+         output_size_per_partition = sum(output_partition_sizes)
+         if output_size_per_partition % self.quant_config.pack_factor != 0:
+             raise ValueError(
+                 "The output size is not aligned with the quantized "
+                 "weight shape. This can be caused by too large "
+                 "tensor parallel size."
+             )
+
+         weight_loader = extra_weight_attrs.get("weight_loader")
+         qweight = PackedvLLMParameter(
+             data=torch.empty(
+                 input_size_per_partition,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             input_dim=0,
+             output_dim=1,
+             packed_dim=1,
+             packed_factor=self.quant_config.pack_factor,
+             weight_loader=weight_loader,
+         )
+
+         qzeros = PackedvLLMParameter(
+             data=torch.empty(
+                 input_size_per_partition // self.quant_config.group_size,
+                 output_size_per_partition // self.quant_config.pack_factor,
+                 dtype=torch.int32,
+             ),
+             input_dim=0,
+             output_dim=1,
+             packed_dim=1,
+             packed_factor=self.quant_config.pack_factor,
+             weight_loader=weight_loader,
+         )
+
+         scales = GroupQuantScaleParameter(
+             data=torch.empty(
+                 input_size_per_partition // self.quant_config.group_size,
+                 output_size_per_partition,
+                 dtype=params_dtype,
+             ),
+             input_dim=0,
+             output_dim=1,
+             weight_loader=weight_loader,
+         )
+
+         layer.register_parameter("qweight", qweight)
+         layer.register_parameter("qzeros", qzeros)
+         layer.register_parameter("scales", scales)
+
+     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+         layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+         layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+         layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         qweight = layer.qweight
+         scales = layer.scales
+         qzeros = layer.qzeros
+         pack_factor = self.quant_config.pack_factor
+         out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)
+         reshaped_x = x.reshape(-1, x.shape[-1])
+
+         out = awq_dequantize(qweight, scales, qzeros)
+         out = torch.matmul(reshaped_x, out)
+
+         if bias is not None:
+             out.add_(bias)
+         return out.reshape(out_shape)
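Note on the new AWQ linear path above: qweight packs eight 4-bit values into each int32 column (pack_factor = 32 // weight_bits), and AWQLinearMethod.apply simply dequantizes the full weight and runs a dense matmul. A minimal shape sketch of that path, where dummy_awq_dequantize is a hypothetical stand-in for sgl_kernel.awq_dequantize that only mimics the output shape:

import torch

# Hypothetical stand-in for sgl_kernel.awq_dequantize: it only reproduces the
# expected output shape, expanding a packed [K, N // pack_factor] int32 qweight
# into a dense [K, N] weight in the dtype of the per-group scales.
def dummy_awq_dequantize(qweight, scales, qzeros, pack_factor=8):
    k, packed_n = qweight.shape
    return torch.randn(k, packed_n * pack_factor, dtype=scales.dtype)

weight_bits = 4
pack_factor = 32 // weight_bits               # 8, as computed in AWQConfig.__init__
k, n, group_size = 128, 256, 64               # toy sizes; the real path uses torch.half on GPU

qweight = torch.zeros(k, n // pack_factor, dtype=torch.int32)
qzeros = torch.zeros(k // group_size, n // pack_factor, dtype=torch.int32)
scales = torch.ones(k // group_size, n, dtype=torch.float32)

x = torch.randn(2, 3, k)
out_shape = x.shape[:-1] + (qweight.shape[-1] * pack_factor,)    # mirrors apply()

w = dummy_awq_dequantize(qweight, scales, qzeros, pack_factor)   # [k, n]
out = torch.matmul(x.reshape(-1, k), w).reshape(out_shape)
print(out.shape)                                                 # torch.Size([2, 3, 256])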
@@ -23,7 +23,6 @@ from sglang.srt.layers.linear import (
      LinearMethodBase,
      UnquantizedLinearMethod,
  )
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import (
      QuantizationConfig,
      QuantizeMethodBase,
@@ -123,6 +122,8 @@ class CompressedTensorsConfig(QuantizationConfig):
                  return UnquantizedLinearMethod()
              layer.scheme = scheme
              return CompressedTensorsLinearMethod(self)
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+
          if isinstance(layer, FusedMoE):
              return CompressedTensorsMoEMethod.get_moe_method(self)
          return None
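Note: this hunk, and most of the changes below, move imports of sglang.srt.layers.moe.fused_moe_triton from module level into the functions that need them, so the quantization modules and the fused-MoE module can import each other without a circular-import failure. A self-contained sketch of the pattern with hypothetical modules pkg_a / pkg_b:

import importlib
import sys
import tempfile
from pathlib import Path

# pkg_a only imports pkg_b at call time, so importing pkg_a never races with
# pkg_b's own module-level "import pkg_a".
pkg_dir = Path(tempfile.mkdtemp())
(pkg_dir / "pkg_a.py").write_text(
    "def get_backend():\n"
    "    from pkg_b import Backend  # deferred import breaks the cycle\n"
    "    return Backend()\n"
)
(pkg_dir / "pkg_b.py").write_text(
    "import pkg_a  # would see a half-initialized pkg_a if pkg_a imported pkg_b at top level\n"
    "class Backend:\n"
    "    pass\n"
)
sys.path.insert(0, str(pkg_dir))
pkg_a = importlib.import_module("pkg_a")
print(type(pkg_a.get_backend()).__name__)  # Backend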
@@ -4,18 +4,19 @@
  import enum
  import logging
  from enum import Enum
- from typing import Callable, List, Optional
+ from typing import TYPE_CHECKING, Callable, List, Optional

  import torch
  from compressed_tensors import CompressionFormat
  from compressed_tensors.quantization import QuantizationStrategy

- from sglang.srt.layers.moe.fused_moe_triton import (
-     FusedMoE,
-     FusedMoEMethodBase,
-     FusedMoeWeightScaleSupported,
- )
- from sglang.srt.layers.moe.topk import select_experts
+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.fused_moe_triton import (
+         FusedMoE,
+         FusedMoEMethodBase,
+         FusedMoeWeightScaleSupported,
+     )
+
  from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
  from sglang.srt.layers.quantization.utils import (
      all_close_1d,
@@ -55,7 +56,13 @@ __all__ = [
  ]


- class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+ class CompressedTensorsMoEMethod:
+     def __new__(cls, *args, **kwargs):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
+
+         if cls is CompressedTensorsMoEMethod:
+             return super().__new__(cls)
+         return super().__new__(cls)

      @staticmethod
      def get_moe_method(
@@ -85,6 +92,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
      def __init__(
          self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
      ):
+         from sglang.srt.layers.moe.fused_moe_triton import (
+             FusedMoEMethodBase,
+             FusedMoeWeightScaleSupported,
+         )
+
          self.quant_config = quant_config
          self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
          self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -112,6 +124,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
          params_dtype: torch.dtype,
          **extra_weight_attrs,
      ):
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported

          params_dtype = torch.float8_e4m3fn

@@ -270,8 +283,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
          scoring_func: str = "softmax",
          correction_bias: Optional[torch.Tensor] = None,
          activation: str = "silu",
+         inplace: bool = True,
+         no_combine: bool = False,
      ) -> torch.Tensor:
-         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
+         from sglang.srt.layers.moe.topk import select_experts

          topk_weights, topk_ids = select_experts(
              hidden_states=x,
@@ -291,7 +307,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
              layer.w2_weight,
              topk_weights=topk_weights,
              topk_ids=topk_ids,
-             inplace=True,
+             inplace=inplace,
              activation=activation,
              use_fp8_w8a8=True,
              w1_scale=layer.w13_weight_scale,
@@ -306,6 +322,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
      def __init__(
          self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
      ):
+         from sglang.srt.layers.moe.fused_moe_triton import (
+             FusedMoEMethodBase,
+             FusedMoeWeightScaleSupported,
+         )
+
          self.quant_config = quant_config
          # TODO: @dsikka: refactor this to use schemes as other kernels
          # are supported + check if the layer is being ignored.
@@ -617,6 +638,9 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
          correction_bias: Optional[torch.Tensor] = None,
          activation: str = "silu",
      ) -> torch.Tensor:
+         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+         from sglang.srt.layers.moe.topk import select_experts
+
          assert activation == "silu", "Only SiLU activation is supported."
          if not VLLM_AVAILABLE:
              raise ImportError(
@@ -24,6 +24,7 @@ import triton.language as tl

  from sglang.srt.utils import (
      direct_register_custom_op,
+     get_bool_env_var,
      get_device_core_count,
      get_device_name,
      get_device_sm,
@@ -43,7 +44,7 @@ if _is_cuda:
      from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8

      sm_version = get_device_sm()
-     if sm_version >= 90 and int(os.getenv("SGL_ENABLE_JIT_DEEPGEMM", "1")):
+     if sm_version >= 90 and get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
          _enable_jit_deepgemm = True


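Note: the replaced check parsed the flag with int(), which only understands numeric strings and raises on values like "true"; get_bool_env_var (imported from sglang.srt.utils in the previous hunk) accepts the usual truthy spellings. A minimal sketch of such a helper, as an assumption about its behaviour rather than sglang's actual implementation:

import os

# Assumed behaviour of a boolean env-var helper: compare against common truthy
# spellings instead of int(), which raises ValueError on strings like "true".
def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    return os.getenv(name, default).strip().lower() in ("1", "true", "yes", "on")

os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "true"
print(get_bool_env_var_sketch("SGL_ENABLE_JIT_DEEPGEMM", default="true"))  # True
os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "0"
print(get_bool_env_var_sketch("SGL_ENABLE_JIT_DEEPGEMM", default="true"))  # False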
@@ -457,12 +457,9 @@ class Fp8LinearOp:
              qinput, x_scale = sgl_scaled_fp8_quant(
                  input_2d,
                  input_scale,
+                 num_token_padding=self.output_padding,
                  use_per_token_if_dynamic=use_per_token_if_dynamic,
              )
-             if self.output_padding:
-                 pad_size = max(self.output_padding - qinput.shape[0], 0)
-                 if pad_size > 0:
-                     qinput = torch.nn.functional.pad(qinput, (0, 0, 0, pad_size))
          else:
              qinput, x_scale = ops.scaled_fp8_quant(
                  input_2d,
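Note: the deleted branch padded the quantized activation up to self.output_padding token rows after quantization; the new call passes num_token_padding to sgl_scaled_fp8_quant so the kernel handles it instead. A small sketch of the row-padding behaviour that was removed (dummy float tensor, no fp8 involved):

import torch
import torch.nn.functional as F

# Mirrors the removed lines: pad the token (row) dimension up to a minimum
# size, leaving the hidden dimension untouched.
def pad_tokens(qinput: torch.Tensor, output_padding: int) -> torch.Tensor:
    pad_size = max(output_padding - qinput.shape[0], 0)
    if pad_size > 0:
        qinput = F.pad(qinput, (0, 0, 0, pad_size))
    return qinput

qinput = torch.zeros(5, 16)            # 5 tokens, hidden size 16
print(pad_tokens(qinput, 8).shape)     # torch.Size([8, 16])  padded up to 8 rows
print(pad_tokens(qinput, 4).shape)     # torch.Size([5, 16])  already large enough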
@@ -11,12 +11,29 @@ from sglang.srt.utils import is_cuda
  _is_cuda = is_cuda()

  try:
-     import vllm
+     from vllm.model_executor.layers.quantization.base_config import QuantizeMethodBase
+     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+     from vllm.model_executor.layers.quantization.gptq_marlin import (
+         GPTQMarlinLinearMethod,
+         GPTQMarlinMoEMethod,
+     )
+     from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
+     from vllm.model_executor.layers.quantization.utils.marlin_utils import (
+         check_marlin_supported,
+     )
+     from vllm.scalar_type import scalar_types

      VLLM_AVAILABLE = True
  except ImportError:
      VLLM_AVAILABLE = False

+     GPTQLinearMethod = MarlinLinearMethod = QuantizeMethodBase = Any
+
+     class scalar_types:
+         uint4b8 = "uint4b8"
+         uint8b128 = "uint8b128"
+
+
  logger = logging.getLogger(__name__)


@@ -117,12 +134,8 @@ class GPTQConfig(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["GPTQLinearMethod"]:
-         if not VLLM_AVAILABLE:
-             raise ImportError("vllm is not installed")
-
-         from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-
+     ) -> Optional[GPTQLinearMethod]:
+         # Delay the import to avoid circular dependency
          from sglang.srt.layers.quantization import get_linear_quant_method

          return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
@@ -131,16 +144,11 @@ class GPTQConfig(QuantizationConfig):
  class GPTQMarlinConfig(QuantizationConfig):
      """Config class for GPTQ Marlin"""

-     if VLLM_AVAILABLE:
-         from vllm.scalar_type import scalar_types
-
-         # (num_bits, is_sym) -> quant_type
-         TYPE_MAP = {
-             (4, True): scalar_types.uint4b8,
-             (8, True): scalar_types.uint8b128,
-         }
-     else:
-         raise ImportError("vllm is not installed")
+     # (num_bits, is_sym) -> quant_type
+     TYPE_MAP = {
+         (4, True): scalar_types.uint4b8,
+         (8, True): scalar_types.uint8b128,
+     }

      def __init__(
          self,
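Note: together with the earlier import hunk in gptq.py, this is what lets TYPE_MAP live at class level even without vllm installed: the names used only in annotations are aliased to Any, and a small scalar_types stub supplies the attributes the table needs. A self-contained sketch with a hypothetical optional package name (fast_quant_lib):

from typing import Any, Optional

# Hypothetical optional dependency; the fallback keeps both the return
# annotation and the module-level TYPE_MAP working when it is missing.
try:
    from fast_quant_lib import FastLinearMethod, scalar_types
    FAST_AVAILABLE = True
except ImportError:
    FAST_AVAILABLE = False
    FastLinearMethod = Any              # annotation-only placeholder

    class scalar_types:                 # stub with just the attributes TYPE_MAP reads
        uint4b8 = "uint4b8"
        uint8b128 = "uint8b128"

TYPE_MAP = {
    (4, True): scalar_types.uint4b8,
    (8, True): scalar_types.uint8b128,
}

def get_quant_method() -> Optional[FastLinearMethod]:
    return None                         # a real implementation would branch on FAST_AVAILABLE

print(TYPE_MAP[(4, True)], FAST_AVAILABLE)  # uint4b8 False (when fast_quant_lib is absent)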
@@ -197,6 +205,7 @@ class GPTQMarlinConfig(QuantizationConfig):
                  "Unsupported quantization config: " f"bits={weight_bits}, sym={is_sym}"
              )

+         # (num_bits, is_sym) -> quant_type
          self.quant_type = self.TYPE_MAP[(weight_bits, is_sym)]

      def __repr__(self) -> str:
@@ -278,15 +287,8 @@ class GPTQMarlinConfig(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["QuantizeMethodBase"]:
-         if not VLLM_AVAILABLE:
-             raise ImportError("vllm is not installed")
-
-         from vllm.model_executor.layers.quantization.gptq_marlin import (
-             GPTQMarlinLinearMethod,
-             GPTQMarlinMoEMethod,
-         )
-
+     ) -> Optional[QuantizeMethodBase]:
+         # Delay the import to avoid circular dependency
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
          from sglang.srt.layers.quantization import get_linear_quant_method

@@ -304,19 +306,12 @@ class GPTQMarlinConfig(QuantizationConfig):

      @classmethod
      def is_gptq_marlin_compatible(cls, quant_config: Dict[str, Any]):
-         if not VLLM_AVAILABLE:
-             return False
-
          quant_method = quant_config.get("quant_method", "").lower()
          num_bits = quant_config.get("bits")
          group_size = quant_config.get("group_size")
          sym = quant_config.get("sym")
          desc_act = quant_config.get("desc_act")

-         from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-             check_marlin_supported,
-         )
-
          if not _is_cuda:
              return False

@@ -427,13 +422,8 @@ class MarlinConfig(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["MarlinLinearMethod"]:
-         if not VLLM_AVAILABLE:
-             raise ImportError("vllm is not installed")
-
-         from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-
-         # Delay import to avoid circular dependency
+     ) -> Optional[MarlinLinearMethod]:
+         # Delay the import to avoid circular dependency
          from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

          if isinstance(layer, LinearBase) or (