sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (195)
  1. sglang/bench_one_batch.py +149 -34
  2. sglang/bench_serving.py +73 -14
  3. sglang/compile_deep_gemm.py +13 -7
  4. sglang/launch_server.py +2 -0
  5. sglang/srt/batch_invariant_ops/__init__.py +2 -0
  6. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
  7. sglang/srt/checkpoint_engine/__init__.py +9 -0
  8. sglang/srt/checkpoint_engine/update.py +317 -0
  9. sglang/srt/compilation/backend.py +1 -1
  10. sglang/srt/configs/__init__.py +2 -0
  11. sglang/srt/configs/deepseek_ocr.py +542 -10
  12. sglang/srt/configs/deepseekvl2.py +95 -194
  13. sglang/srt/configs/kimi_linear.py +160 -0
  14. sglang/srt/configs/mamba_utils.py +66 -0
  15. sglang/srt/configs/model_config.py +30 -7
  16. sglang/srt/constants.py +7 -0
  17. sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
  18. sglang/srt/disaggregation/decode.py +34 -6
  19. sglang/srt/disaggregation/nixl/conn.py +2 -2
  20. sglang/srt/disaggregation/prefill.py +25 -3
  21. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
  22. sglang/srt/distributed/parallel_state.py +9 -12
  23. sglang/srt/entrypoints/engine.py +31 -20
  24. sglang/srt/entrypoints/grpc_server.py +0 -1
  25. sglang/srt/entrypoints/http_server.py +94 -94
  26. sglang/srt/entrypoints/openai/protocol.py +7 -1
  27. sglang/srt/entrypoints/openai/serving_chat.py +42 -0
  28. sglang/srt/entrypoints/openai/serving_completions.py +10 -0
  29. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  30. sglang/srt/environ.py +23 -2
  31. sglang/srt/eplb/expert_distribution.py +64 -1
  32. sglang/srt/eplb/expert_location.py +106 -36
  33. sglang/srt/function_call/function_call_parser.py +2 -0
  34. sglang/srt/function_call/minimax_m2.py +367 -0
  35. sglang/srt/grpc/compile_proto.py +3 -0
  36. sglang/srt/layers/activation.py +6 -0
  37. sglang/srt/layers/attention/ascend_backend.py +233 -5
  38. sglang/srt/layers/attention/attention_registry.py +3 -0
  39. sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
  40. sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
  41. sglang/srt/layers/attention/fla/kda.py +1359 -0
  42. sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
  43. sglang/srt/layers/attention/flashattention_backend.py +19 -8
  44. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  45. sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
  46. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  47. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
  48. sglang/srt/layers/attention/mamba/mamba.py +20 -11
  49. sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
  50. sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
  51. sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
  52. sglang/srt/layers/attention/nsa/transform_index.py +1 -1
  53. sglang/srt/layers/attention/nsa_backend.py +157 -23
  54. sglang/srt/layers/attention/triton_backend.py +4 -1
  55. sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
  56. sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
  57. sglang/srt/layers/attention/utils.py +78 -0
  58. sglang/srt/layers/communicator.py +24 -1
  59. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  60. sglang/srt/layers/layernorm.py +35 -6
  61. sglang/srt/layers/logits_processor.py +9 -20
  62. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  63. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  64. sglang/srt/layers/moe/ep_moe/layer.py +78 -289
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
  67. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
  68. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
  69. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
  70. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  71. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  72. sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
  73. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  75. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  76. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  77. sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
  78. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  79. sglang/srt/layers/moe/topk.py +35 -10
  80. sglang/srt/layers/moe/utils.py +3 -4
  81. sglang/srt/layers/pooler.py +21 -2
  82. sglang/srt/layers/quantization/__init__.py +13 -84
  83. sglang/srt/layers/quantization/auto_round.py +394 -0
  84. sglang/srt/layers/quantization/awq.py +0 -3
  85. sglang/srt/layers/quantization/base_config.py +7 -0
  86. sglang/srt/layers/quantization/fp8.py +68 -63
  87. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  88. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  89. sglang/srt/layers/quantization/gguf.py +566 -0
  90. sglang/srt/layers/quantization/modelopt_quant.py +168 -11
  91. sglang/srt/layers/quantization/mxfp4.py +30 -38
  92. sglang/srt/layers/quantization/unquant.py +23 -45
  93. sglang/srt/layers/quantization/w4afp8.py +38 -2
  94. sglang/srt/layers/radix_attention.py +5 -2
  95. sglang/srt/layers/rotary_embedding.py +130 -46
  96. sglang/srt/layers/sampler.py +12 -1
  97. sglang/srt/lora/lora_registry.py +9 -0
  98. sglang/srt/managers/async_mm_data_processor.py +122 -0
  99. sglang/srt/managers/data_parallel_controller.py +30 -3
  100. sglang/srt/managers/detokenizer_manager.py +3 -0
  101. sglang/srt/managers/io_struct.py +29 -4
  102. sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
  103. sglang/srt/managers/schedule_batch.py +74 -15
  104. sglang/srt/managers/scheduler.py +185 -144
  105. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
  107. sglang/srt/managers/scheduler_pp_mixin.py +7 -2
  108. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  109. sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
  110. sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
  111. sglang/srt/managers/session_controller.py +6 -5
  112. sglang/srt/managers/tokenizer_manager.py +165 -78
  113. sglang/srt/managers/tp_worker.py +24 -1
  114. sglang/srt/mem_cache/base_prefix_cache.py +23 -4
  115. sglang/srt/mem_cache/common.py +1 -0
  116. sglang/srt/mem_cache/hicache_storage.py +7 -1
  117. sglang/srt/mem_cache/memory_pool.py +253 -57
  118. sglang/srt/mem_cache/memory_pool_host.py +12 -5
  119. sglang/srt/mem_cache/radix_cache.py +4 -0
  120. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  121. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
  122. sglang/srt/metrics/collector.py +46 -3
  123. sglang/srt/model_executor/cuda_graph_runner.py +15 -3
  124. sglang/srt/model_executor/forward_batch_info.py +55 -14
  125. sglang/srt/model_executor/model_runner.py +77 -170
  126. sglang/srt/model_executor/npu_graph_runner.py +7 -3
  127. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  128. sglang/srt/model_loader/weight_utils.py +1 -1
  129. sglang/srt/models/bailing_moe.py +9 -2
  130. sglang/srt/models/deepseek_nextn.py +11 -2
  131. sglang/srt/models/deepseek_v2.py +296 -78
  132. sglang/srt/models/glm4.py +391 -77
  133. sglang/srt/models/glm4_moe.py +322 -354
  134. sglang/srt/models/glm4_moe_nextn.py +4 -14
  135. sglang/srt/models/glm4v.py +196 -55
  136. sglang/srt/models/glm4v_moe.py +29 -197
  137. sglang/srt/models/gpt_oss.py +1 -10
  138. sglang/srt/models/kimi_linear.py +678 -0
  139. sglang/srt/models/llama4.py +1 -1
  140. sglang/srt/models/llama_eagle3.py +11 -1
  141. sglang/srt/models/longcat_flash.py +2 -2
  142. sglang/srt/models/minimax_m2.py +922 -0
  143. sglang/srt/models/nvila.py +355 -0
  144. sglang/srt/models/nvila_lite.py +184 -0
  145. sglang/srt/models/qwen2.py +23 -2
  146. sglang/srt/models/qwen2_moe.py +30 -15
  147. sglang/srt/models/qwen3.py +35 -5
  148. sglang/srt/models/qwen3_moe.py +18 -12
  149. sglang/srt/models/qwen3_next.py +7 -0
  150. sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
  151. sglang/srt/multimodal/processors/base_processor.py +1 -0
  152. sglang/srt/multimodal/processors/glm4v.py +1 -1
  153. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  154. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  155. sglang/srt/multiplex/multiplexing_mixin.py +209 -0
  156. sglang/srt/multiplex/pdmux_context.py +164 -0
  157. sglang/srt/parser/conversation.py +7 -1
  158. sglang/srt/parser/reasoning_parser.py +28 -1
  159. sglang/srt/sampling/custom_logit_processor.py +67 -1
  160. sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
  161. sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
  162. sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
  163. sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
  164. sglang/srt/server_args.py +459 -199
  165. sglang/srt/single_batch_overlap.py +2 -4
  166. sglang/srt/speculative/draft_utils.py +16 -0
  167. sglang/srt/speculative/eagle_info.py +42 -36
  168. sglang/srt/speculative/eagle_info_v2.py +68 -25
  169. sglang/srt/speculative/eagle_utils.py +261 -16
  170. sglang/srt/speculative/eagle_worker.py +11 -3
  171. sglang/srt/speculative/eagle_worker_v2.py +15 -9
  172. sglang/srt/speculative/spec_info.py +305 -31
  173. sglang/srt/speculative/spec_utils.py +44 -8
  174. sglang/srt/tracing/trace.py +121 -12
  175. sglang/srt/utils/common.py +142 -74
  176. sglang/srt/utils/hf_transformers_utils.py +38 -12
  177. sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
  178. sglang/test/kits/radix_cache_server_kit.py +50 -0
  179. sglang/test/runners.py +31 -7
  180. sglang/test/simple_eval_common.py +5 -3
  181. sglang/test/simple_eval_humaneval.py +1 -0
  182. sglang/test/simple_eval_math.py +1 -0
  183. sglang/test/simple_eval_mmlu.py +1 -0
  184. sglang/test/simple_eval_mmmu_vlm.py +1 -0
  185. sglang/test/test_deterministic.py +235 -12
  186. sglang/test/test_deterministic_utils.py +2 -1
  187. sglang/test/test_utils.py +7 -1
  188. sglang/version.py +1 -1
  189. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
  190. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
  191. sglang/srt/models/vila.py +0 -306
  192. /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
  193. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
  194. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
  195. {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/moe_runner/deep_gemm.py

@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 import torch
 
+from sglang.srt.layers import deep_gemm_wrapper
 from sglang.srt.layers.moe.moe_runner.base import (
     MoeQuantInfo,
     MoeRunnerConfig,
@@ -15,14 +16,31 @@ from sglang.srt.layers.moe.moe_runner.base import (
     register_pre_permute,
 )
 from sglang.srt.layers.moe.utils import MoeRunnerBackend
-from sglang.srt.utils import dispose_tensor
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+from sglang.srt.utils.offloader import get_offloader
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.deepep import (
+        DeepEPLLCombineInput,
+        DeepEPLLDispatchOutput,
+        DeepEPNormalCombineInput,
+        DeepEPNormalDispatchOutput,
+    )
     from sglang.srt.layers.moe.token_dispatcher.standard import (
         StandardCombineInput,
         StandardDispatchOutput,
     )
 
+_is_hip = is_hip()
+_is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+if not (_is_npu or _is_hip):
+    from sgl_kernel import silu_and_mul
+
+
+_MASKED_GEMM_FAST_ACT = get_bool_env_var("SGLANG_MASKED_GEMM_FAST_ACT")
+
 
 # TODO(kaixih@nvidia): ideally we should merge this logic into
 # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -40,13 +58,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
     return new_x.transpose(1, 2).contiguous().transpose(1, 2)
 
 
+def copy_list_to_gpu_no_ce(arr: List[int]):
+    from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+    tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+    tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+    copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+    return tensor_gpu
+
+
 @dataclass
 class DeepGemmRunnerInput(RunnerInput):
     hidden_states: torch.Tensor
     hidden_states_scale: torch.Tensor
-    masked_m: torch.Tensor
-    expected_m: int
     use_masked_gemm: bool
+    masked_m: Optional[torch.Tensor] = None
+    expected_m: Optional[int] = None
+    m_indices: Optional[torch.Tensor] = None
 
     @property
     def runner_backend(self) -> MoeRunnerBackend:
@@ -84,20 +112,100 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         running_state: dict,
     ) -> DeepGemmRunnerOutput:
 
-        if runner_input.use_masked_gemm:
-            hidden_states = self._run_masked_gemm(
-                runner_input,
-                quant_info,
-                running_state,
+        if not runner_input.use_masked_gemm:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input, quant_info, running_state
             )
         else:
-            hidden_states = self._run_contiguous_gemm(
-                runner_input,
-                quant_info,
-                running_state,
+            hidden_states = self._run_masked_gemm(
+                runner_input, quant_info, running_state
             )
         return DeepGemmRunnerOutput(hidden_states=hidden_states)
 
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        all_tokens = running_state["all_tokens"]
+        hidden_states_device = running_state["hidden_states_device"]
+        hidden_states_dtype = running_state["hidden_states_dtype"]
+        hidden_states_shape = running_state["hidden_states_shape"]
+        m_indices = runner_input.m_indices
+
+        N = quant_info.w13_weight.size(1)
+        K = hidden_states_shape[1]
+        scale_block_size = 128
+
+        w13_weight_fp8 = (
+            quant_info.w13_weight,
+            quant_info.w13_scale,
+        )
+        w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)
+
+        gateup_output = torch.empty(
+            (all_tokens, N),
+            device=hidden_states_device,
+            dtype=torch.bfloat16,
+        )
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = tma_align_input_scale(hidden_states_scale)
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+            (hidden_states, hidden_states_scale),
+            w13_weight_fp8,
+            gateup_output,
+            m_indices,
+        )
+
+        dispose_tensor(hidden_states)
+        dispose_tensor(hidden_states_scale)
+
+        down_input = torch.empty(
+            (
+                all_tokens,
+                N // 2,
+            ),
+            device=gateup_output.device,
+            dtype=torch.bfloat16,
+        )
+        silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
+
+        down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+            down_input,
+            scale_block_size,
+            column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del down_input
+
+        down_output = torch.empty(
+            (all_tokens, K),
+            device=hidden_states_device,
+            dtype=torch.bfloat16,
+        )
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = tma_align_input_scale(down_input_scale)
+
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+            (down_input_fp8, down_input_scale),
+            w2_weight_fp8,
+            down_output,
+            m_indices,
+        )
+
+        return down_output
+
     def _run_masked_gemm(
         self,
         runner_input: DeepGemmRunnerInput,
@@ -109,6 +217,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
         from sglang.srt.layers.moe.ep_moe.kernels import (
             silu_and_mul_masked_post_quant_fwd,
         )
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_8bit,
+        )
 
         hidden_states = runner_input.hidden_states
         hidden_states_scale = runner_input.hidden_states_scale
@@ -122,15 +233,16 @@ class DeepGemmRunnerCore(MoeRunnerCore):
 
         hidden_states_device = running_state["hidden_states_device"]
 
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = hidden_states_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
         # GroupGemm-0
         if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+            if hidden_states_scale.dtype != torch.int:
+                b, s_mn, s_k = hidden_states_scale.shape
+                assert (
+                    s_mn % 4 == 0 and s_k % 4 == 0
+                ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+                hidden_states_scale = _cast_to_e8m0_with_rounding_up(
+                    hidden_states_scale
+                )
         else:
             hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
                 hidden_states_scale
@@ -149,35 +261,49 @@ class DeepGemmRunnerCore(MoeRunnerCore):
             expected_m,
         )
         dispose_tensor(hidden_states)
+        dispose_tensor(hidden_states_scale)
 
         # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float8_e4m3fn,
-        )
         scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
+        if _MASKED_GEMM_FAST_ACT:
+            down_input, down_input_scale = sglang_per_token_group_quant_8bit(
+                x=gateup_output,
+                dst_dtype=torch.float8_e4m3fn,
+                group_size=scale_block_size,
+                masked_m=masked_m,
+                column_major_scales=True,
+                scale_tma_aligned=True,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+                fuse_silu_and_mul=True,
+                enable_v2=True,
+            )
+        else:
+            down_input = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float8_e4m3fn,
+            )
+            down_input_scale = torch.empty(
+                (
+                    gateup_output.shape[0],
+                    gateup_output.shape[1],
+                    gateup_output.shape[2] // 2 // scale_block_size,
+                ),
+                device=hidden_states_device,
+                dtype=torch.float32,
+            )
+            silu_and_mul_masked_post_quant_fwd(
+                gateup_output,
+                down_input,
+                down_input_scale,
+                scale_block_size,
+                masked_m,
+                scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+            )
         del gateup_output
 
         # GroupGemm-1
@@ -198,18 +324,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
             masked_m,
             expected_m,
         )
-        del down_input
 
         return down_output
 
-    def _run_contiguous_gemm(
-        self,
-        runner_input: DeepGemmRunnerInput,
-        quant_info: DeepGemmMoeQuantInfo,
-        running_state: dict,
-    ) -> torch.Tensor:
-        pass
-
     @property
     def runner_backend(self) -> MoeRunnerBackend:
         return MoeRunnerBackend.DEEP_GEMM
@@ -222,6 +339,7 @@ def pre_permute_standard_to_deep_gemm(
     runner_config: MoeRunnerConfig,
     running_state: dict,
 ) -> DeepGemmRunnerInput:
+
    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
 
     hidden_states, topk_output = dispatch_output
@@ -257,9 +375,9 @@ pre_permute_standard_to_deep_gemm(
     return DeepGemmRunnerInput(
         hidden_states=hidden_states,
         hidden_states_scale=hidden_states_scale,
+        use_masked_gemm=True,
         masked_m=masked_m,
         expected_m=expected_m,
-        use_masked_gemm=True,
     )
 
 
@@ -302,3 +420,170 @@ def post_permute_deep_gemm_to_standard(
     return StandardCombineInput(
         hidden_states=output,
     )
+
+
+@register_pre_permute("deepep_ll", "deep_gemm")
+def pre_permute_deepep_ll_to_deep_gemm(
+    dispatch_output: DeepEPLLDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+
+    hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
+        dispatch_output
+    )
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states.shape
+    running_state["hidden_states_dtype"] = hidden_states.dtype
+    running_state["hidden_states_device"] = hidden_states.device
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        use_masked_gemm=True,
+        masked_m=masked_m,
+        expected_m=expected_m,
+    )
+
+
+@register_post_permute("deep_gemm", "deepep_ll")
+def post_permute_deep_gemm_to_deepep_ll(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepEPLLCombineInput:
+
+    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
+
+    return DeepEPLLCombineInput(
+        hidden_states=runner_output.hidden_states,
+        topk_ids=running_state["topk_ids"],
+        topk_weights=running_state["topk_weights"],
+    )
+
+
+@register_pre_permute("deepep_normal", "deep_gemm")
+def pre_permute_deepep_normal_to_deep_gemm(
+    dispatch_output: DeepEPNormalDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+
+    from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter
+
+    (
+        hidden_states,
+        hidden_states_scale,
+        topk_ids,
+        topk_weights,
+        num_recv_tokens_per_expert,
+    ) = dispatch_output
+    assert runner_config.activation == "silu"
+
+    all_tokens = sum(num_recv_tokens_per_expert)
+    running_state["all_tokens"] = all_tokens
+
+    K = hidden_states.shape[1]
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_device = hidden_states.device
+    hidden_states_dtype = hidden_states.dtype
+
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+
+    input_tensor = torch.empty(
+        (all_tokens, K),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+        # TODO check whether need `zeros`
+        input_tensor_scale = torch.zeros(
+            (ceil_div(K // 128, 4), all_tokens),
+            device=hidden_states.device,
+            dtype=torch.int,
+        ).transpose(0, 1)
+    else:
+        input_tensor_scale = torch.empty(
+            (all_tokens, K // 128),
+            device=hidden_states.device,
+            dtype=torch.float32,
+        )
+    m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
+    output_index = torch.empty_like(topk_ids)
+
+    if get_offloader().forbid_copy_engine_usage:
+        num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+            num_recv_tokens_per_expert
+        )
+    else:
+        num_recv_tokens_per_expert_gpu = torch.tensor(
+            num_recv_tokens_per_expert,
+            dtype=torch.int32,
+            pin_memory=True,
+            device="cpu",
+        ).cuda(non_blocking=True)
+    expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+    ep_scatter(
+        hidden_states,
+        hidden_states_scale,
+        topk_ids,
+        num_recv_tokens_per_expert_gpu,
+        expert_start_loc,
+        input_tensor,
+        input_tensor_scale,
+        m_indices,
+        output_index,
+        scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+    )
+    dispose_tensor(hidden_states)
+    dispose_tensor(hidden_states_scale)
+
+    running_state["output_index"] = output_index
+
+    return DeepGemmRunnerInput(
+        hidden_states=input_tensor,
+        hidden_states_scale=input_tensor_scale,
+        use_masked_gemm=False,
+        m_indices=m_indices,
+    )
+
+
+@register_post_permute("deep_gemm", "deepep_normal")
+def post_permute_deep_gemm_to_deepep_normal(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepEPNormalCombineInput:
+
+    from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
+    from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput
+
+    hidden_states = runner_output.hidden_states
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+    output_index = running_state["output_index"]
+
+    gather_out = torch.empty(
+        running_state["hidden_states_shape"],
+        device=running_state["hidden_states_device"],
+        dtype=torch.bfloat16,
+    )
+    ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)
+
+    return DeepEPNormalCombineInput(
+        hidden_states=gather_out,
+        topk_ids=running_state["topk_ids"],
+        topk_weights=running_state["topk_weights"],
+    )

sglang/srt/layers/moe/moe_runner/runner.py

@@ -11,6 +11,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
 )
 from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
 from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
 from sglang.srt.layers.moe.utils import get_moe_a2a_backend
 
 if TYPE_CHECKING:
@@ -31,6 +32,8 @@ class MoeRunner:
 
         if runner_backend.is_triton():
             self.runner_core = TritonRunnerCore(config)
+        elif runner_backend.is_triton_kernels():
+            self.runner_core = TritonKernelsRunnerCore(config)
         elif runner_backend.is_deep_gemm():
             self.runner_core = DeepGemmRunnerCore(config)
         else:

sglang/srt/layers/moe/moe_runner/triton_kernels.py (new file)

@@ -0,0 +1,194 @@
+"""Triton kernels MoE runner backend skeleton."""
+
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+if TYPE_CHECKING:
+    from triton_kernels.matmul_ogs import PrecisionConfig
+    from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# ---------------------------------------------------------------------------
+# Runner IO dataclasses
+# ---------------------------------------------------------------------------
+
+
+@dataclass
+class TritonKernelsRunnerInput(RunnerInput):
+    """Input bundle passed to the triton-kernels runner core."""
+
+    hidden_states: torch.Tensor
+    routing_data: "RoutingData"
+    gather_indx: "GatherIndx"
+    scatter_indx: "ScatterIndx"
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+@dataclass
+class TritonKernelsRunnerOutput(RunnerOutput):
+    """Output bundle returned from the triton-kernels runner core."""
+
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+@dataclass
+class TritonKernelsQuantInfo(MoeQuantInfo):
+    """Quantization payload consumed by the triton-kernels backend."""
+
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    w13_bias: Optional[torch.Tensor] = None
+    w2_bias: Optional[torch.Tensor] = None
+    w13_precision_config: Optional[PrecisionConfig] = None
+    w2_precision_config: Optional[PrecisionConfig] = None
+    global_num_experts: int = -1
+
+
+# ---------------------------------------------------------------------------
+# Runner core
+# ---------------------------------------------------------------------------
+
+
+class TritonKernelsRunnerCore(MoeRunnerCore):
+    """Execute MoE experts via the external triton_kernels package."""
+
+    def run(
+        self,
+        runner_input: TritonKernelsRunnerInput,
+        quant_info: TritonKernelsQuantInfo,
+        running_state: dict,
+    ) -> TritonKernelsRunnerOutput:
+        from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+            triton_kernel_fused_experts,
+            triton_kernel_fused_experts_with_bias,
+        )
+
+        hidden_states = runner_input.hidden_states
+
+        common_kwargs = dict(
+            routing_data=runner_input.routing_data,
+            gather_indx=runner_input.gather_indx,
+            scatter_indx=None if self.config.no_combine else runner_input.scatter_indx,
+            inplace=False,
+            activation=self.config.activation,
+            apply_router_weight_on_input=self.config.apply_router_weight_on_input,
+            global_num_experts=quant_info.global_num_experts,
+        )
+
+        has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
+
+        if has_bias:
+            assert (
+                quant_info.w13_bias is not None and quant_info.w2_bias is not None
+            ), "Bias execution requires both w13_bias and w2_bias"
+            output = triton_kernel_fused_experts_with_bias(
+                hidden_states=hidden_states,
+                w1=quant_info.w13_weight,
+                w1_pcg=quant_info.w13_precision_config,
+                b1=quant_info.w13_bias,
+                w2=quant_info.w2_weight,
+                w2_pcg=quant_info.w2_precision_config,
+                b2=quant_info.w2_bias,
+                gemm1_alpha=self.config.gemm1_alpha,
+                gemm1_clamp_limit=self.config.gemm1_clamp_limit,
+                **common_kwargs,
+            )
+        else:
+            output = triton_kernel_fused_experts(
+                hidden_states=hidden_states,
+                w1=quant_info.w13_weight,
+                w2=quant_info.w2_weight,
+                **common_kwargs,
+            )
+
+        if self.config.no_combine:
+            tokens = runner_input.hidden_states.shape[0]
+            hidden = runner_input.hidden_states.shape[-1]
+            total_rows = output.shape[0]
+            top_k = total_rows // tokens
+            output = output.view(tokens, top_k, hidden)
+
+        return TritonKernelsRunnerOutput(hidden_states=output)
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.TRITON_KERNELS
+
+
+# ---------------------------------------------------------------------------
+# Permute / fused hooks
+# ---------------------------------------------------------------------------
+
+
+@register_pre_permute("standard", "triton_kernel")
+def pre_permute_standard_to_triton_kernels(
+    dispatch_output: "StandardDispatchOutput",
+    quant_info: TritonKernelsQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> TritonKernelsRunnerInput:
+    from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+    hidden_states = dispatch_output.hidden_states
+    topk_output = dispatch_output.topk_output
+
+    assert TopKOutputChecker.format_is_triton_kernels(
+        topk_output
+    ), "Triton-kernel runner expects TritonKernelTopKOutput"
+
+    routing_data, gather_indx, scatter_indx = topk_output
+
+    return TritonKernelsRunnerInput(
+        hidden_states=hidden_states,
+        routing_data=routing_data,
+        gather_indx=gather_indx,
+        scatter_indx=scatter_indx,
+    )
+
+
+@register_post_permute("triton_kernel", "standard")
+def post_permute_triton_kernels_to_standard(
+    runner_output: TritonKernelsRunnerOutput,
+    quant_info: TritonKernelsQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states = runner_output.hidden_states
+
+    if (
+        runner_config.routed_scaling_factor is not None
+        and runner_config.routed_scaling_factor != 1.0
+        and not runner_config.no_combine
+    ):
+        hidden_states.mul_(runner_config.routed_scaling_factor)
+
+    return StandardCombineInput(hidden_states=hidden_states)