sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/layers/moe/utils.py
+++ b/sglang/srt/layers/moe/utils.py
@@ -1,55 +1,85 @@
+from __future__ import annotations
+
 import importlib.util
 from enum import Enum
 from functools import lru_cache
+from typing import TYPE_CHECKING, Optional
 
 from packaging import version as pkg_version
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.utils import logger
 
-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
-    return result
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
 
 
 class MoeA2ABackend(Enum):
 
-    STANDARD = ("standard", "none")
+    NONE = "none"
     DEEPEP = "deepep"
 
     @classmethod
     def _missing_(cls, value):
         if value is None:
-            return cls.STANDARD
+            return cls.NONE
         for member in cls:
-            if value in member.value:
+            if value == member.value:
                 return member
         raise ValueError(f"No {cls.__name__} member for value {value}")
 
+    def is_none(self):
+        return self == MoeA2ABackend.NONE
+
     def is_deepep(self):
         return self == MoeA2ABackend.DEEPEP
 
-    def is_standard(self):
-        return self == MoeA2ABackend.STANDARD
+
+class MoeRunnerBackend(Enum):
+
+    AUTO = "auto"
+    TRITON = "triton"
+    TRITON_KERNEL = "triton_kernel"
+    FLASHINFER = "flashinfer_trtllm"
+    FLASHINFER_CUTLASS = "flashinfer_cutlass"
+    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+
+    def is_auto(self):
+        return self == MoeRunnerBackend.AUTO
+
+    def is_triton(self):
+        return self == MoeRunnerBackend.TRITON
+
+    def is_triton_kernel(self):
+        return self == MoeRunnerBackend.TRITON_KERNEL
+
+    def is_flashinfer_trtllm(self):
+        return self == MoeRunnerBackend.FLASHINFER
+
+    def is_flashinfer_cutlass(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTLASS
+
+    def is_flashinfer_mxfp4(self):
+        return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
 
 class DeepEPMode(Enum):
+
     NORMAL = "normal"
     LOW_LATENCY = "low_latency"
     AUTO = "auto"
 
-    def enable_normal(self):
+    def enable_normal(self) -> bool:
        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
 
-    def enable_low_latency(self):
+    def enable_low_latency(self) -> bool:
         return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
 
-    def resolve(self, is_extend_in_batch: bool):
+    def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
         if self != DeepEPMode.AUTO:
             return self
 
@@ -57,3 +87,114 @@ class DeepEPMode(Enum):
             return DeepEPMode.NORMAL
         else:
             return DeepEPMode.LOW_LATENCY
+
+    def is_normal(self) -> bool:
+        return self == DeepEPMode.NORMAL
+
+    def is_low_latency(self) -> bool:
+        return self == DeepEPMode.LOW_LATENCY
+
+    def is_auto(self) -> bool:
+        return self == DeepEPMode.AUTO
+
+
+MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
+MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+DEEPEP_MODE: Optional[DeepEPMode] = None
+IS_TBO_ENABLED: Optional[bool] = None
+TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
+DEEPEP_CONFIG: Optional[str] = None
+DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
+
+
+def initialize_moe_config(server_args: ServerArgs):
+    global MOE_A2A_BACKEND
+    global MOE_RUNNER_BACKEND
+    global DEEPEP_MODE
+    global DEEPEP_CONFIG
+    global IS_TBO_ENABLED
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+
+    MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
+    MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
+    DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
+    DEEPEP_CONFIG = server_args.deepep_config or ""
+    IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
+    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
+        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
+    )
+
+
+def get_moe_a2a_backend() -> MoeA2ABackend:
+    global MOE_A2A_BACKEND
+    if MOE_A2A_BACKEND is None:
+        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
+        MOE_A2A_BACKEND = MoeA2ABackend(None)
+    return MOE_A2A_BACKEND
+
+
+def get_moe_runner_backend() -> MoeRunnerBackend:
+    global MOE_RUNNER_BACKEND
+    if MOE_RUNNER_BACKEND is None:
+        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+    return MOE_RUNNER_BACKEND
+
+
+def get_deepep_mode() -> DeepEPMode:
+    global DEEPEP_MODE
+    if DEEPEP_MODE is None:
+        logger.warning("DEEPEP_MODE is not initialized, using auto mode")
+        DEEPEP_MODE = DeepEPMode("auto")
+    return DEEPEP_MODE
+
+
+def get_deepep_config() -> str:
+    global DEEPEP_CONFIG
+    if DEEPEP_CONFIG is None:
+        logger.warning("DEEPEP_CONFIG is not initialized, using default config")
+        DEEPEP_CONFIG = ""
+    return DEEPEP_CONFIG
+
+
+def is_tbo_enabled() -> bool:
+    global IS_TBO_ENABLED
+    if IS_TBO_ENABLED is None:
+        logger.warning("IS_TBO_ENABLED is not initialized, using False")
+        IS_TBO_ENABLED = False
+    return IS_TBO_ENABLED
+
+
+def get_tbo_token_distribution_threshold() -> float:
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
+        logger.warning(
+            "TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
+        )
+        TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
+    return TBO_TOKEN_DISTRIBUTION_THRESHOLD
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = get_moe_runner_backend().is_flashinfer_trtllm() and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_cutlass_moe_fp4_allgather():
+    """
+    Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
+    """
+    return (
+        not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+        and get_moe_runner_backend().is_flashinfer_cutlass()
+        and is_dp_attention_enabled()
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
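
These two hunks are sglang/srt/layers/moe/utils.py (entry 82 in the file list): MoE settings stop being read from global_server_args_dict and become a module-level config initialized once from ServerArgs. A minimal usage sketch under that assumption; ServerArgs construction is simplified here, and only names visible in the diff are used:

    from sglang.srt.layers.moe.utils import (
        MoeA2ABackend,
        get_deepep_mode,
        get_moe_a2a_backend,
        initialize_moe_config,
    )
    from sglang.srt.server_args import ServerArgs

    # _missing_ now maps None to the explicit NONE member (formerly STANDARD),
    # and matches values with == instead of `in`:
    assert MoeA2ABackend(None) is MoeA2ABackend.NONE

    server_args = ServerArgs(model_path="dummy/model")  # hypothetical arguments
    initialize_moe_config(server_args)  # normally done once at engine startup

    if get_moe_a2a_backend().is_deepep():
        # AUTO resolves per batch; NORMAL is returned for extend (prefill) batches
        mode = get_deepep_mode().resolve(is_extend_in_batch=True)

Note that every getter also tolerates being called before initialize_moe_config: it logs a warning and falls back to a default, which keeps standalone tooling working.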
--- a/sglang/srt/layers/multimodal.py
+++ b/sglang/srt/layers/multimodal.py
@@ -17,57 +17,173 @@ import torch
 import triton
 import triton.language as tl
 
+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
+
+@triton.jit
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
+
 
 @triton.jit
-def hash_kernel(
-    input_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    PRIME: tl.constexpr,
-    XCONST: tl.constexpr,
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
+    base = pid * TILE
+
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
+
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32
 
-    data = tl.load(input_ptr + offsets, mask=mask, other=0).to(tl.int64)
-    mixed = data ^ (offsets.to(tl.int64) + XCONST)
-    hash_val = mixed * PRIME
-    hash_val = hash_val ^ (hash_val >> 16)
-    hash_val = hash_val * (PRIME ^ XCONST)
-    hash_val = hash_val ^ (hash_val >> 13)
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = (
+        _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+        if False
+        else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+    )
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
+
+
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)
 
-    tl.store(output_ptr + offsets, hash_val, mask=mask)
 
+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)
 
-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
 
+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x
 
-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)
 
-    intermediate_hashes = torch.empty(n, dtype=torch.int64, device=tensor.device)
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0
 
-    # Set cuda device to prevent ValueError: Pointer argument (at 0) cannot be accessed from Triton (cpu tensor?)
-    # Solution from Tri: https://github.com/Dao-AILab/flash-attention/issues/523#issuecomment-1707611579
-    with torch.cuda.device(tensor.device):
-        hash_kernel[grid](
-            tensor,
-            intermediate_hashes,
-            n,
-            BLOCK_SIZE=BLOCK_SIZE,
-            PRIME=PRIME_1,
-            XCONST=PRIME_2,
-        )
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
 
-    # TODO: threads can't be synced on triton kernel
-    final_hash = intermediate_hashes.sum().item()
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt
 
-    return final_hash
+    return _final_splitmix64(int(cur.item()))
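
This hunk is sglang/srt/layers/multimodal.py (entry 83): the single-pass 64-bit multiply-xor hash is replaced by a tiled two-stage scheme, a murmur3-style fmix32 over 32-bit words with position-dependent mixing per tile, an on-device tree reduction of the 64-bit partials, and a splitmix64 finalizer on the host. A usage sketch, assuming a CUDA device is available:

    import torch

    from sglang.srt.layers.multimodal import gpu_tensor_hash

    t = torch.randn(4096, 4096, device="cuda", dtype=torch.float16)
    h = gpu_tensor_hash(t)

    assert 0 <= h < 2**64                   # splitmix64 output is a 64-bit value
    assert h == gpu_tensor_hash(t.clone())  # deterministic for identical bytes
    # The hash covers raw bytes only (via _as_uint32_words), so a view that
    # reinterprets the same storage hashes identically:
    assert h == gpu_tensor_hash(t.view(torch.uint8))

Unlike the old kernel, which allocated one int64 per element and summed them, the partial buffer here holds one uint64 per tile, so the temporary allocation shrinks by a factor of tile_words (8192 by default).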
--- a/sglang/srt/layers/quantization/__init__.py
+++ b/sglang/srt/layers/quantization/__init__.py
@@ -16,7 +16,6 @@ try:
     )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -37,9 +36,9 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
-        Int8TpuConfig
-    ) = DummyConfig
+    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
+        DummyConfig
+    )
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -48,20 +47,9 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
-from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
-
-is_mxfp_supported = mxfp_supported()
-if is_mxfp_supported:
-    from sglang.srt.layers.quantization.fp4 import MxFp4Config
-
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,10 +58,12 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+_is_mxfp_supported = mxfp_supported()
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -86,11 +76,16 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
 }
 
 
@@ -101,29 +96,26 @@ if is_cuda():
             "mxfp4": Mxfp4Config,
         }
     )
-elif is_mxfp_supported and is_hip():
+elif _is_mxfp_supported and is_hip():
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
     BASE_QUANTIZATION_METHODS.update(
         {
-            "quark": MxFp4Config,
-            "mxfp4": MxFp4Config,
+            "quark": QuarkConfig,
+            "mxfp4": Mxfp4Config,
         }
     )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }
 
 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +137,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance
 
 
@@ -239,10 +214,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
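The preceding hunks are sglang/srt/layers/quantization/__init__.py (entry 84). AWQ, GPTQ, and FBGEMM FP8 move from the vllm-backed registry into BASE_QUANTIZATION_METHODS, now served by native sglang/sgl_kernel code per the awq.py and fpgemm_fp8.py changes, which is also why the GPTQ get_quant_method monkey patches disappear. A lookup sketch under that assumption:

    from sglang.srt.layers.quantization import (
        BASE_QUANTIZATION_METHODS,
        VLLM_QUANTIZATION_METHODS,
        get_quantization_config,
    )

    cfg_cls = get_quantization_config("awq_marlin")  # resolves without vllm now
    assert "gptq" in BASE_QUANTIZATION_METHODS       # moved out of the vllm set
    assert "gptq" not in VLLM_QUANTIZATION_METHODS
    assert "aqlm" in VLLM_QUANTIZATION_METHODS       # still vllm-backed; falls
                                                     # back to DummyConfig otherwise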
--- a/sglang/srt/layers/quantization/awq.py
+++ b/sglang/srt/layers/quantization/awq.py
@@ -29,29 +29,26 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-
-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, fused_marlin_moe
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
+
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +61,9 @@ else:
 logger = logging.getLogger(__name__)
 
 
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
 
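get_scalar_types() replaces the vendored scalar_type module (deleted in this release, entry 200) with a lookup provided by quantization/utils.py (entry 103), whose body is outside this diff. A plausible sketch of the pattern; the vllm import path and the mock fallback are assumptions, not the actual implementation:

    def get_scalar_types():
        # Hypothetical: prefer a real ScalarType implementation when one is
        # installed, otherwise return stand-ins so importing awq.py never fails.
        try:
            from vllm.scalar_type import ScalarType, scalar_types

            return ScalarType, scalar_types
        except ImportError:

            class MockScalarType:
                pass

            class MockScalarTypes:
                def __getattr__(self, name):
                    return None  # e.g. scalar_types.uint4b8 -> None

            return MockScalarType, MockScalarTypes()

Binding the result once at module scope (ScalarType, scalar_types = get_scalar_types()) keeps the rest of the file unchanged relative to the old vendored import.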
@@ -516,7 +516,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)
 
         # Repack weights from AWQ format to marlin format.
-        marlin_qweight = ops.awq_marlin_repack(
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -684,7 +684,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )
 
-        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -693,7 +693,7 @@
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
 
-        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
@@ -740,13 +740,12 @@
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
-        assert activation == "silu", "Only SiLU activation is supported."
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
--- a/sglang/srt/layers/quantization/base_config.py
+++ b/sglang/srt/layers/quantization/base_config.py
@@ -9,6 +9,7 @@ import torch
 from torch import nn
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         raise NotImplementedError
 
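
Entry 86, sglang/srt/layers/quantization/base_config.py: FusedMoEMethodBase.apply folds the loose keyword arguments (activation, apply_router_weight_on_input, inplace, no_combine, routed_scaling_factor) into a single MoeRunnerConfig, matching the AWQMoEMethod hunk above. A sketch of a conforming subclass; of the config's fields, only activation is confirmed by this diff:

    import torch

    from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase

    class SketchMoEMethod(FusedMoEMethodBase):
        def apply(self, layer, x, topk_output, moe_runner_config) -> torch.Tensor:
            # Per-call knobs now travel on the config object rather than **kwargs.
            assert moe_runner_config.activation == "silu"
            # A real implementation dispatches to a fused MoE kernel using the
            # layer weights and topk_output; identity keeps the sketch runnable.
            return x

Centralizing these knobs means a new runner option no longer requires touching every quantization method's apply signature.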