sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff shows the changes between publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the packages as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_one_batch.py +0 -6
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +24 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -1
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +27 -2
  24. sglang/srt/entrypoints/http_server.py +12 -0
  25. sglang/srt/entrypoints/openai/protocol.py +2 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +22 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +9 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +11 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
  37. sglang/srt/layers/attention/triton_backend.py +85 -46
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +51 -3
  46. sglang/srt/layers/dp_attention.py +23 -4
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +5 -1
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  60. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  61. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  62. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  63. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  64. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  65. sglang/srt/layers/moe/router.py +15 -9
  66. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  67. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  68. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  69. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  70. sglang/srt/layers/moe/topk.py +167 -83
  71. sglang/srt/layers/moe/utils.py +159 -18
  72. sglang/srt/layers/quantization/__init__.py +13 -14
  73. sglang/srt/layers/quantization/awq.py +7 -7
  74. sglang/srt/layers/quantization/base_config.py +2 -6
  75. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  76. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
  77. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  78. sglang/srt/layers/quantization/fp8.py +127 -119
  79. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  80. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  81. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  82. sglang/srt/layers/quantization/gptq.py +5 -4
  83. sglang/srt/layers/quantization/marlin_utils.py +11 -3
  84. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  85. sglang/srt/layers/quantization/modelopt_quant.py +165 -68
  86. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  87. sglang/srt/layers/quantization/mxfp4.py +206 -37
  88. sglang/srt/layers/quantization/quark/quark.py +390 -0
  89. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  90. sglang/srt/layers/quantization/unquant.py +34 -70
  91. sglang/srt/layers/quantization/utils.py +25 -0
  92. sglang/srt/layers/quantization/w4afp8.py +7 -8
  93. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  94. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  95. sglang/srt/layers/radix_attention.py +6 -0
  96. sglang/srt/layers/rotary_embedding.py +1 -0
  97. sglang/srt/lora/lora_manager.py +21 -22
  98. sglang/srt/lora/lora_registry.py +3 -3
  99. sglang/srt/lora/mem_pool.py +26 -24
  100. sglang/srt/lora/utils.py +10 -12
  101. sglang/srt/managers/cache_controller.py +76 -18
  102. sglang/srt/managers/detokenizer_manager.py +10 -2
  103. sglang/srt/managers/io_struct.py +9 -0
  104. sglang/srt/managers/mm_utils.py +1 -1
  105. sglang/srt/managers/schedule_batch.py +4 -9
  106. sglang/srt/managers/scheduler.py +25 -16
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/template_manager.py +7 -5
  109. sglang/srt/managers/tokenizer_manager.py +60 -21
  110. sglang/srt/managers/tp_worker.py +1 -0
  111. sglang/srt/managers/utils.py +59 -1
  112. sglang/srt/mem_cache/allocator.py +7 -5
  113. sglang/srt/mem_cache/allocator_ascend.py +0 -11
  114. sglang/srt/mem_cache/hicache_storage.py +14 -4
  115. sglang/srt/mem_cache/memory_pool.py +3 -3
  116. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  117. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  118. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  119. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  120. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  121. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  122. sglang/srt/model_executor/cuda_graph_runner.py +25 -12
  123. sglang/srt/model_executor/forward_batch_info.py +4 -1
  124. sglang/srt/model_executor/model_runner.py +43 -32
  125. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  126. sglang/srt/model_loader/loader.py +24 -6
  127. sglang/srt/models/dbrx.py +12 -6
  128. sglang/srt/models/deepseek.py +2 -1
  129. sglang/srt/models/deepseek_nextn.py +3 -1
  130. sglang/srt/models/deepseek_v2.py +224 -223
  131. sglang/srt/models/ernie4.py +2 -2
  132. sglang/srt/models/glm4_moe.py +25 -63
  133. sglang/srt/models/glm4v.py +52 -1
  134. sglang/srt/models/glm4v_moe.py +8 -11
  135. sglang/srt/models/gpt_oss.py +34 -74
  136. sglang/srt/models/granitemoe.py +0 -1
  137. sglang/srt/models/grok.py +376 -48
  138. sglang/srt/models/interns1.py +12 -47
  139. sglang/srt/models/internvl.py +6 -51
  140. sglang/srt/models/llama4.py +0 -2
  141. sglang/srt/models/minicpm3.py +0 -1
  142. sglang/srt/models/mixtral.py +0 -2
  143. sglang/srt/models/nemotron_nas.py +435 -0
  144. sglang/srt/models/olmoe.py +0 -1
  145. sglang/srt/models/phi4mm.py +3 -21
  146. sglang/srt/models/qwen2_5_vl.py +2 -0
  147. sglang/srt/models/qwen2_moe.py +3 -18
  148. sglang/srt/models/qwen3.py +2 -2
  149. sglang/srt/models/qwen3_classification.py +7 -1
  150. sglang/srt/models/qwen3_moe.py +9 -38
  151. sglang/srt/models/step3_vl.py +2 -1
  152. sglang/srt/models/xverse_moe.py +11 -5
  153. sglang/srt/multimodal/processors/base_processor.py +3 -3
  154. sglang/srt/multimodal/processors/internvl.py +7 -2
  155. sglang/srt/multimodal/processors/llava.py +11 -7
  156. sglang/srt/offloader.py +433 -0
  157. sglang/srt/operations.py +6 -1
  158. sglang/srt/reasoning_parser.py +4 -3
  159. sglang/srt/server_args.py +237 -104
  160. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
  161. sglang/srt/speculative/eagle_utils.py +36 -13
  162. sglang/srt/speculative/eagle_worker.py +56 -3
  163. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  164. sglang/srt/two_batch_overlap.py +16 -11
  165. sglang/srt/utils.py +68 -70
  166. sglang/test/runners.py +8 -5
  167. sglang/test/test_block_fp8.py +5 -6
  168. sglang/test/test_block_fp8_ep.py +13 -19
  169. sglang/test/test_cutlass_moe.py +4 -6
  170. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  171. sglang/test/test_fp4_moe.py +4 -3
  172. sglang/test/test_utils.py +7 -0
  173. sglang/utils.py +0 -1
  174. sglang/version.py +1 -1
  175. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
  176. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
  177. sglang/srt/layers/quantization/fp4.py +0 -557
  178. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  179. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/utils.py
@@ -1,55 +1,85 @@
+from __future__ import annotations
+
 import importlib.util
 from enum import Enum
 from functools import lru_cache
+from typing import TYPE_CHECKING, Optional
 
 from packaging import version as pkg_version
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.utils import logger
 
-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
-    return result
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
 
 
 class MoeA2ABackend(Enum):
 
-    STANDARD = ("standard", "none")
+    NONE = "none"
     DEEPEP = "deepep"
 
     @classmethod
     def _missing_(cls, value):
         if value is None:
-            return cls.STANDARD
+            return cls.NONE
         for member in cls:
-            if value in member.value:
+            if value == member.value:
                 return member
         raise ValueError(f"No {cls.__name__} member for value {value}")
 
+    def is_none(self):
+        return self == MoeA2ABackend.NONE
+
     def is_deepep(self):
         return self == MoeA2ABackend.DEEPEP
 
-    def is_standard(self):
-        return self == MoeA2ABackend.STANDARD
+
+class MoeRunnerBackend(Enum):
+
+    AUTO = "auto"
+    TRITON = "triton"
+    TRITON_KERNEL = "triton_kernel"
+    FLASHINFER = "flashinfer_trtllm"
+    FLASHINFER_CUTLASS = "flashinfer_cutlass"
+    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+
+    def is_auto(self):
+        return self == MoeRunnerBackend.AUTO
+
+    def is_triton(self):
+        return self == MoeRunnerBackend.TRITON
+
+    def is_triton_kernel(self):
+        return self == MoeRunnerBackend.TRITON_KERNEL
+
+    def is_flashinfer_trtllm(self):
+        return self == MoeRunnerBackend.FLASHINFER
+
+    def is_flashinfer_cutlass(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTLASS
+
+    def is_flashinfer_mxfp4(self):
+        return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
 
 class DeepEPMode(Enum):
+
     NORMAL = "normal"
     LOW_LATENCY = "low_latency"
     AUTO = "auto"
 
-    def enable_normal(self):
+    def enable_normal(self) -> bool:
         return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
 
-    def enable_low_latency(self):
+    def enable_low_latency(self) -> bool:
         return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
 
-    def resolve(self, is_extend_in_batch: bool):
+    def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
         if self != DeepEPMode.AUTO:
             return self
 
@@ -57,3 +87,114 @@ class DeepEPMode(Enum):
             return DeepEPMode.NORMAL
         else:
             return DeepEPMode.LOW_LATENCY
+
+    def is_normal(self) -> bool:
+        return self == DeepEPMode.NORMAL
+
+    def is_low_latency(self) -> bool:
+        return self == DeepEPMode.LOW_LATENCY
+
+    def is_auto(self) -> bool:
+        return self == DeepEPMode.AUTO
+
+
+MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
+MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+DEEPEP_MODE: Optional[DeepEPMode] = None
+IS_TBO_ENABLED: Optional[bool] = None
+TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
+DEEPEP_CONFIG: Optional[str] = None
+DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
+
+
+def initialize_moe_config(server_args: ServerArgs):
+    global MOE_A2A_BACKEND
+    global MOE_RUNNER_BACKEND
+    global DEEPEP_MODE
+    global DEEPEP_CONFIG
+    global IS_TBO_ENABLED
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+
+    MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
+    MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
+    DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
+    DEEPEP_CONFIG = server_args.deepep_config or ""
+    IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
+    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
+        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
+    )
+
+
+def get_moe_a2a_backend() -> MoeA2ABackend:
+    global MOE_A2A_BACKEND
+    if MOE_A2A_BACKEND is None:
+        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
+        MOE_A2A_BACKEND = MoeA2ABackend(None)
+    return MOE_A2A_BACKEND
+
+
+def get_moe_runner_backend() -> MoeRunnerBackend:
+    global MOE_RUNNER_BACKEND
+    if MOE_RUNNER_BACKEND is None:
+        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+    return MOE_RUNNER_BACKEND
+
+
+def get_deepep_mode() -> DeepEPMode:
+    global DEEPEP_MODE
+    if DEEPEP_MODE is None:
+        logger.warning("DEEPEP_MODE is not initialized, using auto mode")
+        DEEPEP_MODE = DeepEPMode("auto")
+    return DEEPEP_MODE
+
+
+def get_deepep_config() -> str:
+    global DEEPEP_CONFIG
+    if DEEPEP_CONFIG is None:
+        logger.warning("DEEPEP_CONFIG is not initialized, using default config")
+        DEEPEP_CONFIG = ""
+    return DEEPEP_CONFIG
+
+
+def is_tbo_enabled() -> bool:
+    global IS_TBO_ENABLED
+    if IS_TBO_ENABLED is None:
+        logger.warning("IS_TBO_ENABLED is not initialized, using False")
+        IS_TBO_ENABLED = False
+    return IS_TBO_ENABLED
+
+
+def get_tbo_token_distribution_threshold() -> float:
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
+        logger.warning(
+            "TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
+        )
+        TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
+    return TBO_TOKEN_DISTRIBUTION_THRESHOLD
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = get_moe_runner_backend().is_flashinfer_trtllm() and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_cutlass_moe_fp4_allgather():
+    """
+    Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
+    """
+    return (
+        not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+        and get_moe_runner_backend().is_flashinfer_cutlass()
+        and is_dp_attention_enabled()
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
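
For orientation, a minimal sketch of how the new module-level MoE configuration is meant to be consumed (module path per the file list above; the setup_moe helper is hypothetical wiring, not sglang's actual startup code):

    from sglang.srt.layers.moe.utils import (
        get_deepep_mode,
        get_moe_a2a_backend,
        initialize_moe_config,
    )

    def setup_moe(server_args):
        # Called once at engine startup, before MoE layers are constructed,
        # replacing reads of global_server_args_dict.
        initialize_moe_config(server_args)

        # Layers then query the typed globals through the getters.
        if get_moe_a2a_backend().is_deepep():
            # AUTO presumably resolves extend/prefill batches to NORMAL and
            # decode batches to LOW_LATENCY (the branch condition sits just
            # outside the hunk above).
            mode = get_deepep_mode().resolve(is_extend_in_batch=True)

Note that the getters only warn and fall back to defaults when initialize_moe_config was never called, so import order no longer matters the way it did with global_server_args_dict.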

sglang/srt/layers/quantization/__init__.py
@@ -16,7 +16,6 @@ try:
     )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -37,9 +36,9 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
-        Int8TpuConfig
-    ) = DummyConfig
+    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
+        DummyConfig
+    )
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -48,13 +47,8 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
-from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
-
-is_mxfp_supported = mxfp_supported()
-if is_mxfp_supported:
-    from sglang.srt.layers.quantization.fp4 import MxFp4Config
-
 from sglang.srt.layers.quantization.fp8 import Fp8Config
+from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
@@ -67,6 +61,9 @@ from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+_is_mxfp_supported = mxfp_supported()
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -88,6 +85,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
 }
 
 
@@ -98,11 +96,13 @@ if is_cuda():
             "mxfp4": Mxfp4Config,
         }
     )
-elif is_mxfp_supported and is_hip():
+elif _is_mxfp_supported and is_hip():
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
    BASE_QUANTIZATION_METHODS.update(
         {
-            "quark": MxFp4Config,
-            "mxfp4": MxFp4Config,
+            "quark": QuarkConfig,
+            "mxfp4": Mxfp4Config,
         }
     )
 # VLLM-dependent quantization methods
@@ -110,7 +110,6 @@ VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,

sglang/srt/layers/quantization/awq.py
@@ -33,7 +33,8 @@ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 from sglang.srt.utils import is_cuda, is_hip
 
@@ -739,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
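
AWQMoEMethod.apply now types its routing input as StandardTopKOutput. Its real definition lives in sglang/srt/layers/moe/topk.py (reworked in this release); judging from the unpacking `topk_weights, topk_ids, router_logits = topk_output` seen later in this diff, it is tuple-like, roughly:

    from typing import NamedTuple

    import torch

    class StandardTopKOutput(NamedTuple):
        # Inferred shape; the actual class in topk.py may differ.
        topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights
        topk_ids: torch.Tensor  # [num_tokens, top_k] selected expert indices
        router_logits: torch.Tensor  # [num_tokens, num_experts] raw router scores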

sglang/srt/layers/quantization/base_config.py
@@ -9,6 +9,7 @@ import torch
 from torch import nn
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         raise NotImplementedError
 
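
The keyword arguments deleted here are bundled into a single MoeRunnerConfig object. The real definition is in sglang/srt/layers/moe/moe_runner/base.py (added in this release); the dataclass below is an approximation inferred from the exact keywords it replaces, so field order, extra fields, and defaults may differ:

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class MoeRunnerConfig:
        activation: str = "silu"
        apply_router_weight_on_input: bool = False
        inplace: bool = True
        no_combine: bool = False
        routed_scaling_factor: Optional[float] = None

A call site that previously wrote apply(layer, x, topk_output, activation="silu", inplace=False) now writes apply(layer, x, topk_output, MoeRunnerConfig(activation="silu", inplace=False)); every FusedMoEMethodBase subclass in the hunks below is migrated the same way.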

sglang/srt/layers/quantization/blockwise_int8.py
@@ -3,7 +3,7 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import torch
 from torch.nn import Module
@@ -22,6 +22,7 @@ from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 ACTIVATION_SCHEMES = ["static", "dynamic"]
@@ -348,12 +349,7 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
 
@@ -363,15 +359,11 @@ class BlockInt8MoEMethod(FusedMoEMethodBase):
             layer.w13_weight,
             layer.w2_weight,
             topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            apply_router_weight_on_input=apply_router_weight_on_input,
+            moe_runner_config=moe_runner_config,
             use_int8_w8a8=True,
             w1_scale=(layer.w13_weight_scale_inv),
             w2_scale=(layer.w2_weight_scale_inv),
             a1_scale=layer.w13_input_scale,
             a2_scale=layer.w2_input_scale,
             block_shape=self.quant_config.weight_block_size,
-            no_combine=no_combine,
-            routed_scaling_factor=routed_scaling_factor,
         )

sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -19,15 +19,30 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_cpu,
+    is_cuda,
+    is_hip,
+    is_npu,
+    set_weight_attrs,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
     )
 
+_is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if _use_aiter:
+    from aiter.ops.shuffle import shuffle_weight
+
+    from sglang.srt.layers.moe.rocm_moe_utils import rocm_fused_experts_tkw1
 
 try:
     import vllm
@@ -264,37 +279,66 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             max_w13_scales, requires_grad=False
         )
 
+        if _use_aiter:
+            with torch.no_grad():
+                # Pre-shuffle weights
+                layer.w13_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w13_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+                layer.w2_weight = torch.nn.Parameter(
+                    shuffle_weight(layer.w2_weight.data, (16, 16)),
+                    requires_grad=False,
+                )
+                torch.cuda.empty_cache()
+
     def apply(
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
 
-        return fused_experts(
-            x,
-            layer.w13_weight,
-            layer.w2_weight,
-            topk_output=topk_output,
-            inplace=inplace,
-            activation=activation,
-            use_fp8_w8a8=True,
-            per_channel_quant=self.weight_quant.strategy
-            == QuantizationStrategy.CHANNEL,
-            w1_scale=layer.w13_weight_scale,
-            w2_scale=layer.w2_weight_scale,
-            a1_scale=layer.w13_input_scale,
-            a2_scale=layer.w2_input_scale,
-            apply_router_weight_on_input=apply_router_weight_on_input,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        if (
+            _use_aiter
+            and self.weight_quant.strategy == QuantizationStrategy.CHANNEL
+            and moe_runner_config.apply_router_weight_on_input
+        ):
+            topk_weights, topk_ids, _ = topk_output
+            return rocm_fused_experts_tkw1(
+                hidden_states=x,
+                w1=layer.w13_weight,
+                w2=layer.w2_weight,
+                topk_weights=topk_weights,
+                topk_ids=topk_ids,
+                activation=moe_runner_config.activation,
+                apply_router_weight_on_input=moe_runner_config.apply_router_weight_on_input,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
+        else:
+            return fused_experts(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_output=topk_output,
+                moe_runner_config=moe_runner_config,
+                use_fp8_w8a8=True,
+                per_channel_quant=self.weight_quant.strategy
+                == QuantizationStrategy.CHANNEL,
+                w1_scale=layer.w13_weight_scale,
+                w2_scale=layer.w2_weight_scale,
+                a1_scale=layer.w13_input_scale,
+                a2_scale=layer.w2_input_scale,
+            )
 
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
@@ -601,12 +645,12 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
 
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         topk_weights, topk_ids, router_logits = topk_output
 
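
A condensed view of the new ROCm dispatch rule in CompressedTensorsW8A8Fp8MoEMethod.apply, assuming get_bool_env_var treats values like "1"/"true" as enabled (illustrative; the real gate also requires a HIP build via is_hip()):

    import os

    def _should_use_rocm_tkw1(strategy_is_channel: bool, apply_router_weight_on_input: bool) -> bool:
        # aiter's tkw1 kernel is taken only for channel-quantized FP8 weights
        # when router weights are applied on the input side; everything else
        # falls through to the generic Triton fused_experts path.
        use_aiter = os.environ.get("SGLANG_USE_AITER", "false").lower() in ("1", "true")
        return use_aiter and strategy_is_channel and apply_router_weight_on_input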

sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py
@@ -7,7 +7,8 @@ logger = logging.getLogger(__name__)
 
 def _compute_enable_deep_gemm():
     sm_version = get_device_sm()
-    if sm_version < 90:
+    # TODO fix blackwell fp8
+    if sm_version != 90:
         return False
 
     try:
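
Note the changed semantics: the old gate (sm_version < 90 returns False) enabled DeepGEMM on SM90 and newer, while the new gate (sm_version != 90 returns False) enables it on Hopper (SM90) only, excluding Blackwell (SM100) until its FP8 path is fixed. Equivalent predicate, for clarity:

    def deep_gemm_sm_supported(sm_version: int) -> bool:
        # 0.5.0rc2: return sm_version >= 90   (Hopper and newer)
        # 0.5.1: Hopper only, per the "TODO fix blackwell fp8" above
        return sm_version == 90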