sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (108)
  1. sglang/bench_serving.py +72 -10
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/deepseekvl2.py +10 -1
  4. sglang/srt/configs/model_config.py +6 -16
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/custom_op.py +5 -0
  7. sglang/srt/distributed/device_communicators/custom_all_reduce.py +28 -80
  8. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  9. sglang/srt/distributed/parallel_state.py +32 -5
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/entrypoints/http_server.py +7 -1
  12. sglang/srt/entrypoints/verl_engine.py +2 -0
  13. sglang/srt/function_call_parser.py +0 -1
  14. sglang/srt/layers/attention/flashattention_backend.py +582 -125
  15. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  16. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  17. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  18. sglang/srt/layers/dp_attention.py +12 -1
  19. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  20. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  21. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  26. sglang/srt/layers/moe/topk.py +79 -6
  27. sglang/srt/layers/quantization/__init__.py +137 -165
  28. sglang/srt/layers/quantization/awq.py +200 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  30. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  31. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  32. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  33. sglang/srt/layers/quantization/gptq.py +30 -40
  34. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  35. sglang/srt/layers/quantization/utils.py +1 -1
  36. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  37. sglang/srt/lora/backend/base_backend.py +4 -4
  38. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  39. sglang/srt/lora/backend/triton_backend.py +5 -8
  40. sglang/srt/lora/layers.py +19 -33
  41. sglang/srt/lora/lora_manager.py +20 -7
  42. sglang/srt/lora/mem_pool.py +12 -6
  43. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  44. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  45. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  46. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  47. sglang/srt/lora/utils.py +6 -0
  48. sglang/srt/managers/cache_controller.py +34 -11
  49. sglang/srt/managers/io_struct.py +4 -2
  50. sglang/srt/managers/mm_utils.py +202 -156
  51. sglang/srt/managers/multimodal_processor.py +0 -2
  52. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  53. sglang/srt/managers/multimodal_processors/clip.py +44 -0
  54. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  55. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  56. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  57. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  58. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  59. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  60. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  61. sglang/srt/managers/schedule_batch.py +185 -127
  62. sglang/srt/managers/scheduler.py +29 -23
  63. sglang/srt/managers/tokenizer_manager.py +1 -2
  64. sglang/srt/managers/tp_worker.py +3 -0
  65. sglang/srt/managers/utils.py +1 -6
  66. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  67. sglang/srt/mem_cache/memory_pool.py +72 -6
  68. sglang/srt/mem_cache/paged_allocator.py +39 -0
  69. sglang/srt/metrics/collector.py +23 -53
  70. sglang/srt/model_executor/cuda_graph_runner.py +16 -13
  71. sglang/srt/model_executor/forward_batch_info.py +10 -10
  72. sglang/srt/model_executor/model_runner.py +64 -59
  73. sglang/srt/model_loader/loader.py +19 -1
  74. sglang/srt/model_loader/weight_utils.py +6 -3
  75. sglang/srt/models/clip.py +568 -0
  76. sglang/srt/models/deepseek_janus_pro.py +12 -17
  77. sglang/srt/models/deepseek_v2.py +339 -123
  78. sglang/srt/models/deepseek_vl2.py +105 -104
  79. sglang/srt/models/gemma3_causal.py +12 -2
  80. sglang/srt/models/gemma3_mm.py +20 -80
  81. sglang/srt/models/llama.py +4 -1
  82. sglang/srt/models/llava.py +31 -19
  83. sglang/srt/models/llavavid.py +16 -7
  84. sglang/srt/models/minicpmo.py +63 -147
  85. sglang/srt/models/minicpmv.py +17 -27
  86. sglang/srt/models/mllama.py +29 -14
  87. sglang/srt/models/qwen2.py +9 -6
  88. sglang/srt/models/qwen2_5_vl.py +21 -31
  89. sglang/srt/models/qwen2_vl.py +20 -21
  90. sglang/srt/openai_api/adapter.py +106 -93
  91. sglang/srt/openai_api/protocol.py +10 -5
  92. sglang/srt/patch_torch.py +71 -0
  93. sglang/srt/platforms/interface.py +371 -0
  94. sglang/srt/server_args.py +120 -25
  95. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  96. sglang/srt/speculative/eagle_utils.py +140 -28
  97. sglang/srt/speculative/eagle_worker.py +94 -25
  98. sglang/srt/utils.py +137 -51
  99. sglang/test/runners.py +27 -2
  100. sglang/test/test_custom_ops.py +55 -0
  101. sglang/test/test_utils.py +14 -27
  102. sglang/utils.py +2 -2
  103. sglang/version.py +1 -1
  104. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +10 -5
  105. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +108 -99
  106. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  107. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  108. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py

@@ -12,17 +12,19 @@
 # limitations under the License.
 # ==============================================================================

+import os
 from typing import Callable, Optional

 import torch
 import torch.nn.functional as F

+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

 _is_cuda = is_cuda()
 _is_hip = is_hip()

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder

 expert_distribution_recorder = ExpertDistributionRecorder()

@@ -102,11 +104,13 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
     )  # [n, n_group]
@@ -122,15 +126,30 @@ def grouped_topk(
     )  # [n, e]
     tmp_scores = scores.masked_fill(~score_mask.bool(), 0.0)  # [n, e]
     topk_weights, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.

     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def biased_grouped_topk(
+def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
     correction_bias: torch.Tensor,
@@ -138,11 +157,13 @@ def biased_grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
+    n_share_experts_fusion: int = 0,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

     scores = gating_output.sigmoid()
     num_token = scores.shape[0]
+    num_experts = scores.shape[1]
     scores_for_choice = scores.view(num_token, -1) + correction_bias.unsqueeze(0)
     group_scores = (
         scores_for_choice.view(num_token, num_expert_group, -1)
@@ -165,12 +186,59 @@ def biased_grouped_topk(
     _, topk_ids = torch.topk(tmp_scores, k=topk, dim=-1, sorted=False)
     topk_weights = scores.gather(1, topk_ids)

+    if n_share_experts_fusion:
+        topk_ids[:, -1] = torch.randint(
+            low=num_experts,
+            high=num_experts + n_share_experts_fusion,
+            size=(topk_ids.size(0),),
+            dtype=topk_ids.dtype,
+            device=topk_ids.device,
+        )
+        topk_weights[:, -1] = (
+            topk_weights[:, :-1].sum(dim=-1) / 2.5
+        )  # 2.5 is the routed_scaling_factor.
+
     if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+        topk_weights_sum = (
+            topk_weights.sum(dim=-1, keepdim=True)
+            if n_share_experts_fusion == 0
+            else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
+        )
+        topk_weights = topk_weights / topk_weights_sum

     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


+def biased_grouped_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    correction_bias: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+    num_expert_group: int = 0,
+    topk_group: int = 0,
+    compiled: bool = True,
+    n_share_experts_fusion: int = 0,
+):
+    biased_grouped_topk_fn = (
+        torch.compile(
+            biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
+        )
+        if compiled
+        else biased_grouped_topk_impl
+    )
+    return biased_grouped_topk_fn(
+        hidden_states,
+        gating_output,
+        correction_bias,
+        topk,
+        renormalize,
+        num_expert_group,
+        topk_group,
+        n_share_experts_fusion=n_share_experts_fusion,
+    )
+
+
 def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
@@ -183,7 +251,10 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
 ):
-    # DeekSeekv2 uses grouped_top_k
+    n_share_experts_fusion = 0
+    if global_server_args_dict["n_share_experts_fusion"] is not None:
+        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -195,6 +266,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
         else:
             topk_weights, topk_ids = biased_grouped_topk(
@@ -205,6 +277,7 @@ def select_experts(
                 renormalize=renormalize,
                 num_expert_group=num_expert_group,
                 topk_group=topk_group,
+                n_share_experts_fusion=n_share_experts_fusion,
             )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
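
The topk.py changes above add an optional shared-expert fusion path: when `n_share_experts_fusion` (now read from `global_server_args_dict` in `select_experts`) is non-zero, the last routing slot is redirected to a randomly chosen fused shared expert whose id lies past the routed experts, its weight is set to the sum of the routed weights divided by the routed scaling factor (hard-coded as 2.5), and renormalization then uses the routed slots only. The snippet below is a minimal standalone sketch of that routing rule on toy tensors; `toy_topk_with_shared_fusion` is a hypothetical helper name, and it omits the expert-group masking and bias correction that the real `grouped_topk` / `biased_grouped_topk_impl` keep.

import torch


def toy_topk_with_shared_fusion(
    gating_output: torch.Tensor,
    topk: int,
    n_share_experts_fusion: int = 0,
    routed_scaling_factor: float = 2.5,  # matches the hard-coded 2.5 in the diff
):
    # Plain softmax top-k routing; expert-group masking is omitted for brevity.
    scores = torch.softmax(gating_output, dim=-1)
    num_experts = scores.shape[1]
    topk_weights, topk_ids = torch.topk(scores, k=topk, dim=-1, sorted=False)

    if n_share_experts_fusion:
        # Redirect the last slot to a random fused shared expert (ids start
        # right after the routed experts) and give it a fixed share of the
        # routed weight mass.
        topk_ids[:, -1] = torch.randint(
            low=num_experts,
            high=num_experts + n_share_experts_fusion,
            size=(topk_ids.size(0),),
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor

    # Renormalize over the routed slots only when a shared expert is fused in.
    denom = (
        topk_weights.sum(dim=-1, keepdim=True)
        if n_share_experts_fusion == 0
        else topk_weights[:, :-1].sum(dim=-1, keepdim=True)
    )
    topk_weights = topk_weights / denom
    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)


# 4 tokens, 8 routed experts, top-4 routing, one fused shared expert (id 8).
weights, ids = toy_topk_with_shared_fusion(torch.randn(4, 8), topk=4, n_share_experts_fusion=1)

The `biased_grouped_topk` wrapper added in the diff also exposes a `compiled` flag so the `torch.compile` step can be skipped, rather than baking the decorator onto the implementation.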

sglang/srt/layers/quantization/__init__.py

@@ -9,13 +9,24 @@ import torch

 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq import AWQConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import (
+        AWQMarlinConfig,
+        AWQMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+    from vllm.model_executor.layers.quantization.gptq_marlin import (
+        GPTQMarlinLinearMethod,
+        GPTQMarlinMoEMethod,
+    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -23,33 +34,39 @@ try:
     from vllm.model_executor.layers.quantization.qqq import QQQConfig
     from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig

-    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
-
     VLLM_AVAILABLE = True
 except ImportError:
     VLLM_AVAILABLE = False

     # Define empty classes as placeholders when vllm is not available
     class DummyConfig:
-        pass
+        def override_quantization_method(self, *args, **kwargs):
+            return None
+
+    AQLMConfig = AWQMarlinConfig = BitsAndBytesConfig = CompressedTensorsConfig = (
+        DeepSpeedFPConfig
+    ) = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = (
+        MarlinConfig
+    ) = QQQConfig = Int8TpuConfig = DummyConfig

-    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
-        CompressedTensorsConfig
-    ) = DummyConfig
-    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
-        GPTQMarlin24Config
-    ) = DummyConfig
-    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig

+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.awq import AWQConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    UnquantizedEmbeddingMethod,
+)

 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -58,29 +75,29 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt": ModelOptFp8Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
 }

-# Add vllm-dependent methods if available
-QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
-if VLLM_AVAILABLE:
-    VLLM_QUANTIZATION_METHODS = {
-        "aqlm": AQLMConfig,
-        "awq": AWQConfig,
-        "deepspeedfp": DeepSpeedFPConfig,
-        "tpu_int8": Int8TpuConfig,
-        "fbgemm_fp8": FBGEMMFp8Config,
-        "marlin": MarlinConfig,
-        "gguf": GGUFConfig,
-        "gptq_marlin_24": GPTQMarlin24Config,
-        "awq_marlin": AWQMarlinConfig,
-        "bitsandbytes": BitsAndBytesConfig,
-        "qqq": QQQConfig,
-        "experts_int8": ExpertsInt8Config,
-        "gptq_marlin": GPTQMarlinConfig,
-        "gptq": GPTQConfig,
-    }
-    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+# VLLM-dependent quantization methods
+VLLM_QUANTIZATION_METHODS = {
+    "aqlm": AQLMConfig,
+    "awq": AWQConfig,
+    "deepspeedfp": DeepSpeedFPConfig,
+    "tpu_int8": Int8TpuConfig,
+    "fbgemm_fp8": FBGEMMFp8Config,
+    "marlin": MarlinConfig,
+    "gguf": GGUFConfig,
+    "gptq_marlin_24": GPTQMarlin24Config,
+    "awq_marlin": AWQMarlinConfig,
+    "bitsandbytes": BitsAndBytesConfig,
+    "qqq": QQQConfig,
+    "experts_int8": ExpertsInt8Config,
+    "gptq_marlin": GPTQMarlinConfig,
+    "gptq": GPTQConfig,
+}
+
+QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}


 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
@@ -89,6 +106,12 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
             f"Invalid quantization method: {quantization}. "
             f"Available methods: {list(QUANTIZATION_METHODS.keys())}"
         )
+    if quantization in VLLM_QUANTIZATION_METHODS and not VLLM_AVAILABLE:
+        raise ValueError(
+            f"{quantization} quantization requires some operators from vllm. "
+            "Pleaes install vllm by `pip install vllm==0.7.2`"
+        )
+
     return QUANTIZATION_METHODS[quantization]


@@ -153,13 +176,6 @@ def get_linear_quant_method(
     prefix: str,
     linear_method_cls: type,
 ):
-
-    from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
     cloned_config = deepcopy(config)
     parallel_lm_head_quantized = (
         isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
@@ -186,31 +202,19 @@


 def gptq_get_quant_method(self, layer, prefix):
-    if not VLLM_AVAILABLE:
-        return None
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE

-    try:
-        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinLinearMethod,
-            GPTQMarlinMoEMethod,
-        )
+    if isinstance(layer, FusedMoE):
+        return GPTQMarlinMoEMethod(self)

-        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-        if isinstance(layer, FusedMoE):
-            return GPTQMarlinMoEMethod(self)
-
-        if isinstance(self, GPTQConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-            )
-        elif isinstance(self, GPTQMarlinConfig):
-            return get_linear_quant_method(
-                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-            )
-    except ImportError:
-        pass
+    if isinstance(self, GPTQConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+        )
+    elif isinstance(self, GPTQMarlinConfig):
+        return get_linear_quant_method(
+            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+        )
     return None


@@ -229,33 +233,28 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
         builtins.isinstance = original_isinstance
         return

-    try:
-        from vllm.model_executor.layers.fused_moe import FusedMoE
-        from vllm.model_executor.layers.linear import LinearBase
-        from vllm.model_executor.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding,
-        )
+    from vllm.model_executor.layers.fused_moe import FusedMoE
+    from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )

-        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-        from sglang.srt.layers.moe.fused_moe_triton.layer import (
-            FusedMoE as PatchedFusedMoE,
-        )
-        from sglang.srt.layers.vocab_parallel_embedding import (
-            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-        )
+    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )

-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
-
-        builtins.isinstance = patched_isinstance
-    except ImportError:
-        return
+    def patched_isinstance(obj, classinfo):
+        if classinfo is LinearBase:
+            return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
+        return original_isinstance(obj, classinfo)
+
+    builtins.isinstance = patched_isinstance


 def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
@@ -263,91 +262,64 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    if not VLLM_AVAILABLE:
-        return
-
-    try:
-        original_apply = class_obj.apply
-        sig = inspect.signature(original_apply)
-        param_names = list(sig.parameters.keys())
-        has_correction_bias = "e_score_correction_bias" in param_names
-
-        def new_apply(
-            self,
-            layer: torch.nn.Module,
-            x: torch.Tensor,
-            router_logits: torch.Tensor,
-            top_k: int,
-            renormalize: bool,
-            use_grouped_topk: bool,
-            topk_group: Optional[int] = None,
-            num_expert_group: Optional[int] = None,
-            custom_routing_function: Optional[Callable] = None,
-            correction_bias: Optional[torch.Tensor] = None,
-            activation: str = "silu",
-            inplace: bool = True,
-            no_combine: bool = False,
-        ):
-            assert activation == "silu"
-            assert inplace and not no_combine
-
-            kwargs = {
-                "self": self,
-                "layer": layer,
-                "x": x,
-                "router_logits": router_logits,
-                "top_k": top_k,
-                "renormalize": renormalize,
-                "use_grouped_topk": use_grouped_topk,
-                "topk_group": topk_group,
-                "num_expert_group": num_expert_group,
-                "custom_routing_function": custom_routing_function,
-            }
-            if correction_bias is not None:
-                if not has_correction_bias:
-                    raise ValueError(
-                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
-                    )
-                kwargs["e_score_correction_bias"] = correction_bias
-            return original_apply(**kwargs)
-
-        setattr(class_obj, "apply", new_apply)
-    except (ImportError, AttributeError):
-        return
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)


 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    if not VLLM_AVAILABLE:
-        return
+    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)

-    try:
-        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-            CompressedTensorsW8A8Fp8MoEMethod,
-            CompressedTensorsWNA16MoEMethod,
-        )
-        from vllm.model_executor.layers.quantization.gptq_marlin import (
-            GPTQMarlinMoEMethod,
-        )
-
-        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-        monkey_patch_moe_apply(AWQMoEMethod)
-        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
-    except ImportError:
-        return
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


 # Only apply monkey patches if vllm is available
 if VLLM_AVAILABLE:
     monkey_patch_quant_configs()
-
-
-__all__ = [
-    "get_quantization_config",
-    "QUANTIZATION_METHODS",
-]
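
The quantization/__init__.py rework above moves the GPTQ and AWQ configs to sglang-native modules, adds `moe_wna16` to the vllm-independent base methods, and keeps `VLLM_QUANTIZATION_METHODS` defined even without vllm (backed by `DummyConfig` placeholders), so `get_quantization_config` can now say exactly why a method is unavailable instead of listing it as invalid. A small usage sketch, assuming an environment where vllm is not installed:

from sglang.srt.layers.quantization import get_quantization_config

# Methods implemented natively in sglang resolve without vllm.
w8a8_fp8_cls = get_quantization_config("w8a8_fp8")

# vllm-backed methods stay registered, but now fail fast with an actionable
# message when vllm itself is missing.
try:
    get_quantization_config("marlin")
except ValueError as err:
    print(err)  # suggests `pip install vllm==0.7.2`

With vllm installed, the same call returns the vllm config class as before, and the monkey patches at the bottom of the file are applied only when `VLLM_AVAILABLE` is true.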