sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff shows the changes between two package versions that were publicly released to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/__init__.py
@@ -6,53 +6,82 @@ from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
 
 import torch
-from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
-from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
-from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
-    CompressedTensorsConfig,
-)
-from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
-from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
-from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-from vllm.model_executor.layers.quantization.gptq_marlin_24 import GPTQMarlin24Config
-from vllm.model_executor.layers.quantization.marlin import MarlinConfig
-from vllm.model_executor.layers.quantization.qqq import QQQConfig
-from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+try:
+    from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
+    from vllm.model_executor.layers.quantization.awq import AWQConfig
+    from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
+    from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
+    from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
+    from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
+    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
+    from vllm.model_executor.layers.quantization.gguf import GGUFConfig
+    from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
+        GPTQMarlin24Config,
+    )
+    from vllm.model_executor.layers.quantization.marlin import MarlinConfig
+    from vllm.model_executor.layers.quantization.qqq import QQQConfig
+    from vllm.model_executor.layers.quantization.tpu_int8 import Int8TpuConfig
+
+    from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+    # Define empty classes as placeholders when vllm is not available
+    class DummyConfig:
+        pass
+
+    AQLMConfig = AWQConfig = AWQMarlinConfig = BitsAndBytesConfig = (
+        CompressedTensorsConfig
+    ) = DummyConfig
+    DeepSpeedFPConfig = ExpertsInt8Config = FBGEMMFp8Config = GGUFConfig = (
+        GPTQMarlin24Config
+    ) = DummyConfig
+    MarlinConfig = QQQConfig = Int8TpuConfig = DummyConfig
 
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
+from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+    CompressedTensorsConfig,
+)
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
-QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
-    "aqlm": AQLMConfig,
-    "awq": AWQConfig,
-    "deepspeedfp": DeepSpeedFPConfig,
-    "tpu_int8": Int8TpuConfig,
+# Base quantization methods that don't depend on vllm
+BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
     "blockwise_int8": BlockInt8Config,
-    "fbgemm_fp8": FBGEMMFp8Config,
-    "marlin": MarlinConfig,
     "modelopt": ModelOptFp8Config,
-    "gguf": GGUFConfig,
-    "gptq_marlin_24": GPTQMarlin24Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "awq_marlin": AWQMarlinConfig,
-    "gptq": GPTQConfig,
-    "compressed-tensors": CompressedTensorsConfig,
-    "bitsandbytes": BitsAndBytesConfig,
-    "qqq": QQQConfig,
-    "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "compressed-tensors": CompressedTensorsConfig,
 }
 
+# Add vllm-dependent methods if available
+QUANTIZATION_METHODS = BASE_QUANTIZATION_METHODS.copy()
+if VLLM_AVAILABLE:
+    VLLM_QUANTIZATION_METHODS = {
+        "aqlm": AQLMConfig,
+        "awq": AWQConfig,
+        "deepspeedfp": DeepSpeedFPConfig,
+        "tpu_int8": Int8TpuConfig,
+        "fbgemm_fp8": FBGEMMFp8Config,
+        "marlin": MarlinConfig,
+        "gguf": GGUFConfig,
+        "gptq_marlin_24": GPTQMarlin24Config,
+        "awq_marlin": AWQMarlinConfig,
+        "bitsandbytes": BitsAndBytesConfig,
+        "qqq": QQQConfig,
+        "experts_int8": ExpertsInt8Config,
+        "gptq_marlin": GPTQMarlinConfig,
+        "gptq": GPTQConfig,
+    }
+    QUANTIZATION_METHODS.update(VLLM_QUANTIZATION_METHODS)
+
 
 def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     if quantization not in QUANTIZATION_METHODS:
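Note: this hunk makes vllm an optional dependency. All vllm-backed config classes are imported inside a try/except ImportError, a VLLM_AVAILABLE flag records the outcome, and the method registry is split into a base dict plus a vllm-only dict merged in when the import succeeded; a method whose backing dependency is missing is simply absent from QUANTIZATION_METHODS. A minimal, self-contained sketch of the same pattern — `heavylib` and `AqlmConfig` here are illustrative placeholders, not real sglang or vllm names:

```python
# Minimal sketch of the optional-dependency registry pattern above.
from typing import Dict, Type


class QuantConfig:  # stand-in for the QuantizationConfig base class
    pass


class Fp8Config(QuantConfig):  # has no optional dependency
    pass


try:
    from heavylib import AqlmConfig  # ImportError here disables the extras

    HEAVYLIB_AVAILABLE = True
except ImportError:
    HEAVYLIB_AVAILABLE = False

    class AqlmConfig(QuantConfig):  # placeholder so the name still resolves
        pass


# Methods usable without the optional dependency.
BASE_METHODS: Dict[str, Type[QuantConfig]] = {"fp8": Fp8Config}

METHODS = BASE_METHODS.copy()
if HEAVYLIB_AVAILABLE:
    # Registered only when the optional import succeeded.
    METHODS.update({"aqlm": AqlmConfig})


def get_config(name: str) -> Type[QuantConfig]:
    if name not in METHODS:
        raise ValueError(f"Unknown quantization method: {name}")
    return METHODS[name]


print(sorted(METHODS))  # without heavylib installed: ['fp8']
```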
@@ -157,25 +186,31 @@ def get_linear_quant_method(
 
 
 def gptq_get_quant_method(self, layer, prefix):
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
-    )
+    if not VLLM_AVAILABLE:
+        return None
+
+    try:
+        from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinLinearMethod,
+            GPTQMarlinMoEMethod,
+        )
 
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
+        if isinstance(layer, FusedMoE):
+            return GPTQMarlinMoEMethod(self)
 
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
+        if isinstance(self, GPTQConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
+            )
+        elif isinstance(self, GPTQMarlinConfig):
+            return get_linear_quant_method(
+                self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
+            )
+    except ImportError:
+        pass
    return None
 
 
@@ -187,33 +222,40 @@ def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
     Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
     can recognize sglang layers
     """
+    if not VLLM_AVAILABLE:
+        return
 
     if reverse:
         builtins.isinstance = original_isinstance
         return
 
-    from vllm.model_executor.layers.fused_moe import FusedMoE
-    from vllm.model_executor.layers.linear import LinearBase
-    from vllm.model_executor.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding,
-    )
-
-    from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
-    )
+    try:
+        from vllm.model_executor.layers.fused_moe import FusedMoE
+        from vllm.model_executor.layers.linear import LinearBase
+        from vllm.model_executor.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding,
+        )
 
-    def patched_isinstance(obj, classinfo):
-        if classinfo is LinearBase:
-            return original_isinstance(obj, PatchedLinearBase)
-        if classinfo is FusedMoE:
-            return original_isinstance(obj, PatchedFusedMoE)
-        if classinfo is VocabParallelEmbedding:
-            return original_isinstance(obj, PatchedVocabParallelEmbedding)
-        return original_isinstance(obj, classinfo)
+        from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
+        from sglang.srt.layers.moe.fused_moe_triton.layer import (
+            FusedMoE as PatchedFusedMoE,
+        )
+        from sglang.srt.layers.vocab_parallel_embedding import (
+            VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+        )
 
-    builtins.isinstance = patched_isinstance
+        def patched_isinstance(obj, classinfo):
+            if classinfo is LinearBase:
+                return original_isinstance(obj, PatchedLinearBase)
+            if classinfo is FusedMoE:
+                return original_isinstance(obj, PatchedFusedMoE)
+            if classinfo is VocabParallelEmbedding:
+                return original_isinstance(obj, PatchedVocabParallelEmbedding)
+            return original_isinstance(obj, classinfo)
+
+        builtins.isinstance = patched_isinstance
+    except ImportError:
+        return
 
 
@@ -221,72 +263,88 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
     Monkey patch the apply function of vllm's FusedMoEMethodBase.
     Convert sglang arguments to vllm arguments.
     """
-    original_apply = class_obj.apply
-    sig = inspect.signature(original_apply)
-    param_names = list(sig.parameters.keys())
-    has_correction_bias = "e_score_correction_bias" in param_names
-
-    def new_apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
-        activation: str = "silu",
-        inplace: bool = True,
-        no_combine: bool = False,
-    ):
-        assert activation == "silu"
-        assert inplace and not no_combine
-
-        kwargs = {
-            "self": self,
-            "layer": layer,
-            "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
-        }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
-        return original_apply(**kwargs)
-
-    setattr(class_obj, "apply", new_apply)
+    if not VLLM_AVAILABLE:
+        return
+
+    try:
+        original_apply = class_obj.apply
+        sig = inspect.signature(original_apply)
+        param_names = list(sig.parameters.keys())
+        has_correction_bias = "e_score_correction_bias" in param_names
+
+        def new_apply(
+            self,
+            layer: torch.nn.Module,
+            x: torch.Tensor,
+            router_logits: torch.Tensor,
+            top_k: int,
+            renormalize: bool,
+            use_grouped_topk: bool,
+            topk_group: Optional[int] = None,
+            num_expert_group: Optional[int] = None,
+            custom_routing_function: Optional[Callable] = None,
+            correction_bias: Optional[torch.Tensor] = None,
+            activation: str = "silu",
+            inplace: bool = True,
+            no_combine: bool = False,
+        ):
+            assert activation == "silu"
+            assert inplace and not no_combine
+
+            kwargs = {
+                "self": self,
+                "layer": layer,
+                "x": x,
+                "router_logits": router_logits,
+                "top_k": top_k,
+                "renormalize": renormalize,
+                "use_grouped_topk": use_grouped_topk,
+                "topk_group": topk_group,
+                "num_expert_group": num_expert_group,
+                "custom_routing_function": custom_routing_function,
+            }
+            if correction_bias is not None:
+                if not has_correction_bias:
+                    raise ValueError(
+                        "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                    )
+                kwargs["e_score_correction_bias"] = correction_bias
+            return original_apply(**kwargs)
+
+        setattr(class_obj, "apply", new_apply)
+    except (ImportError, AttributeError):
+        return
 
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
-    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
-        CompressedTensorsW8A8Fp8MoEMethod,
-        CompressedTensorsWNA16MoEMethod,
-    )
-    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod
+    if not VLLM_AVAILABLE:
+        return
 
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
+    try:
+        from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+        from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+            CompressedTensorsW8A8Fp8MoEMethod,
+            CompressedTensorsWNA16MoEMethod,
+        )
+        from vllm.model_executor.layers.quantization.gptq_marlin import (
+            GPTQMarlinMoEMethod,
+        )
 
-    monkey_patch_moe_apply(AWQMoEMethod)
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
-    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
+        setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
+        setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
+
+        monkey_patch_moe_apply(AWQMoEMethod)
+        monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+        monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
+    except ImportError:
+        return
 
 
-monkey_patch_quant_configs()
+# Only apply monkey patches if vllm is available
+if VLLM_AVAILABLE:
+    monkey_patch_quant_configs()
 
 
 __all__ = [
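Note: `monkey_patch_moe_apply` adapts sglang's calling convention to whatever `apply` signature the installed vllm exposes. It probes the wrapped function once with `inspect.signature` and only forwards `e_score_correction_bias` when that parameter exists, raising a clear error otherwise. The same probe-then-forward idea in isolation — `vendor_apply` is a made-up stand-in, not a vllm function:

```python
# Sketch of probing a wrapped function's signature before forwarding an
# optional kwarg, as monkey_patch_moe_apply does above.
import inspect
from typing import Optional


def vendor_apply(x: float, e_score_correction_bias: Optional[float] = None) -> float:
    return x + (e_score_correction_bias or 0.0)


# Probe once, at patch time, rather than on every call.
HAS_BIAS_KWARG = "e_score_correction_bias" in inspect.signature(vendor_apply).parameters


def adapted_apply(x: float, correction_bias: Optional[float] = None) -> float:
    kwargs = {"x": x}
    if correction_bias is not None:
        if not HAS_BIAS_KWARG:
            # Too-old vendor version: fail loudly instead of dropping the arg.
            raise ValueError("Installed vendor library does not accept the bias")
        kwargs["e_score_correction_bias"] = correction_bias
    return vendor_apply(**kwargs)


print(adapted_apply(1.0, correction_bias=0.5))  # 1.5
```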
sglang/srt/layers/quantization/base_config.py
@@ -38,6 +38,11 @@ class QuantizeMethodBase(ABC):
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
 
+    def __init__(self):
+        super().__init__()
+        # mapping is updated by models as they initialize
+        self.packed_modules_mapping: Dict[str, List[str]] = dict()
+
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
sglang/srt/layers/quantization/blockwise_int8.py
@@ -5,7 +5,6 @@ from typing import Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
-from vllm.model_executor.layers.quantization.utils.quant_utils import is_layer_skipped
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
@@ -19,6 +18,7 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
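Note: these last two hunks swap vllm's `is_layer_skipped` for the copy added to sglang's own `sglang/srt/layers/quantization/utils.py` in this release (+153 lines, see the file list above), removing the remaining vllm import in `blockwise_int8.py`. Assuming the helper mirrors the vllm one it replaces, it decides whether a layer's prefix falls in the quantization config's ignore list; a rough sketch:

```python
# Hedged sketch of an is_layer_skipped-style check, assuming it mirrors the
# vllm helper it replaces: a layer is skipped when its prefix is covered by
# the config's ignore list.
from typing import List


def is_layer_skipped_sketch(prefix: str, ignored_layers: List[str]) -> bool:
    return any(
        prefix == ignored or prefix.startswith(ignored + ".")
        for ignored in ignored_layers
    )


assert is_layer_skipped_sketch("model.layers.0.mlp", ["model.layers.0.mlp"])
assert not is_layer_skipped_sketch("model.layers.1.mlp", ["model.layers.0.mlp"])
```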