sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to their public registry. It is provided for informational purposes only.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py (+13 -11)
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import logging
 from contextlib import suppress
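The added from __future__ import annotations (PEP 563) stops Python from evaluating annotations at definition time, which is what lets the later hunks in this file drop the quotes around forward references such as "CompressedTensorsLinearMethod". A minimal sketch of that behavior, with hypothetical names not taken from this package:

from __future__ import annotations


class Config:
    # Without the future import, this unquoted forward reference would raise
    # NameError when the class body is evaluated, since Method is defined later.
    def get_method(self) -> Method:
        return Method(self)


class Method:
    def __init__(self, config: Config) -> None:
        self.config = config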
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel
 
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -40,9 +37,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
+        WNA16_SUPPORTED_BITS,
+        CompressedTensorsWNA16,
+    )
 
     VLLM_AVAILABLE = True
 except ImportError:
@@ -97,7 +98,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.config = config
         self.packed_modules_mapping = packed_modules_mapping
 
-    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
+    def get_linear_method(self) -> CompressedTensorsLinearMethod:
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +118,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
 
         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
@@ -138,7 +140,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
+    def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
         ignore: List[str] = cast(List[str], config.get("ignore", []))
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +359,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def _get_scheme_from_parts(
         self, weight_quant: BaseModel, input_quant: BaseModel
-    ) -> "CompressedTensorsScheme":
+    ) -> CompressedTensorsScheme:
 
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +437,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def get_scheme(
         self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional["CompressedTensorsScheme"]:
+    ) -> Optional[CompressedTensorsScheme]:
         """
         compressed-tensors supports non uniform in the following way:
 

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py (+23 -79)
@@ -1,15 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
@@ -18,16 +20,14 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
-_is_cuda = is_cuda()
-_is_npu = is_npu()
-_is_cpu_amx_available = cpu_has_amx_support()
-_is_cpu = is_cpu()
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
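The new TYPE_CHECKING block makes TopKOutput and CompressedTensorsConfig visible to static type checkers without importing either module at runtime; combined with the __future__ import above, this is the standard way to annotate against a module that cannot be imported eagerly (typically to break an import cycle, here with compressed_tensors.py). The general shape of the pattern, with a hypothetical module name:

from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by type checkers (mypy, pyright); never at runtime.
    from circular_or_heavy_module import SomeType  # hypothetical


def consume(value: SomeType) -> None:
    # Safe at runtime: PEP 563 keeps the annotation as an unevaluated string.
    print(value)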
@@ -51,7 +51,7 @@ __all__ = [
 ]
 
 
-class CompressedTensorsMoEMethod:
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     def __new__(cls, *args, **kwargs):
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
@@ -59,7 +59,7 @@ class CompressedTensorsMoEMethod:
 
     @staticmethod
     def get_moe_method(
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        quant_config: CompressedTensorsConfig,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -82,9 +82,7 @@ class CompressedTensorsMoEMethod:
 
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -270,47 +268,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
-        apply_router_weight_on_input: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
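This hunk shrinks apply() from a dozen routing parameters to a single topk_output argument, with everything after the bare * marker keyword-only: expert selection now happens in the caller rather than inside each quantization method, and the result is forwarded to fused_experts unchanged. A caller-side sketch consistent with the new signature (compute_routing is a hypothetical stand-in for whatever produces the TopKOutput; it is not shown in this diff):

def run_moe(moe_method, layer, x, router_logits, compute_routing):
    # Routing is computed once, outside the quant method.
    topk_output = compute_routing(x, router_logits)  # assumed helper
    return moe_method.apply(
        layer,
        x,
        topk_output,          # positional: the only routing input left
        activation="silu",    # keyword-only after the bare "*" in the signature
        inplace=True,
    )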
@@ -327,9 +299,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -589,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
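vllm_ops was previously bound at module import time behind a platform check (removed in the import hunk above); importing it inside the method instead defers the dependency to the moment Marlin repacking actually runs, so merely importing this module no longer requires vLLM. A sketch of that deferred-import pattern (the wrapper name is hypothetical; the remaining repack arguments are elided as in the hunk):

def _repack_with_marlin(layer):
    # Deferred import: vLLM is touched only if this code path executes,
    # so environments without vLLM can still import the enclosing module.
    from vllm import _custom_ops as vllm_ops

    return vllm_ops.gptq_marlin_moe_repack(
        layer.w13_weight_packed,
        layer.w13_g_idx_sort_indices,
        # ...remaining size/bit-width arguments elided, as in the hunk above
    )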
@@ -628,43 +600,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for " "fused Marlin MoE method."
-            )
 
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
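The unpacking on the added line pins down the shape of the new TopKOutput type: an iterable of routing weights, expert ids, and the original router logits. A plausible minimal definition consistent with this diff; the real class in sglang.srt.layers.moe.topk may carry more fields:

from typing import NamedTuple

import torch


class TopKOutput(NamedTuple):
    # Field names and order inferred from "topk_weights, topk_ids,
    # router_logits = topk_output" in the hunk above.
    topk_weights: torch.Tensor   # per-token routing weights, shape [tokens, top_k]
    topk_ids: torch.Tensor       # selected expert indices, shape [tokens, top_k]
    router_logits: torch.Tensor  # raw gating scores, shape [tokens, num_experts]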