sglang 0.5.3__py3-none-any.whl → 0.5.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (112)
  1. sglang/bench_one_batch.py +0 -2
  2. sglang/bench_serving.py +224 -127
  3. sglang/compile_deep_gemm.py +3 -0
  4. sglang/launch_server.py +0 -14
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/falcon_h1.py +12 -58
  7. sglang/srt/configs/mamba_utils.py +117 -0
  8. sglang/srt/configs/model_config.py +68 -31
  9. sglang/srt/configs/nemotron_h.py +286 -0
  10. sglang/srt/configs/qwen3_next.py +11 -43
  11. sglang/srt/disaggregation/decode.py +7 -18
  12. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +1 -1
  13. sglang/srt/disaggregation/nixl/conn.py +55 -23
  14. sglang/srt/disaggregation/prefill.py +17 -32
  15. sglang/srt/entrypoints/engine.py +2 -2
  16. sglang/srt/entrypoints/grpc_request_manager.py +10 -23
  17. sglang/srt/entrypoints/grpc_server.py +220 -80
  18. sglang/srt/entrypoints/http_server.py +49 -1
  19. sglang/srt/entrypoints/openai/protocol.py +159 -31
  20. sglang/srt/entrypoints/openai/serving_chat.py +13 -71
  21. sglang/srt/entrypoints/openai/serving_tokenize.py +144 -0
  22. sglang/srt/environ.py +4 -0
  23. sglang/srt/function_call/function_call_parser.py +8 -6
  24. sglang/srt/grpc/sglang_scheduler_pb2.py +78 -70
  25. sglang/srt/grpc/sglang_scheduler_pb2.pyi +64 -6
  26. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +88 -0
  27. sglang/srt/layers/attention/attention_registry.py +31 -22
  28. sglang/srt/layers/attention/fla/layernorm_gated.py +47 -30
  29. sglang/srt/layers/attention/flashattention_backend.py +0 -1
  30. sglang/srt/layers/attention/flashinfer_backend.py +223 -6
  31. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -1
  32. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +165 -59
  33. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -1
  34. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +9 -4
  35. sglang/srt/layers/attention/mamba/mamba.py +189 -241
  36. sglang/srt/layers/attention/mamba/mamba2_metadata.py +211 -0
  37. sglang/srt/layers/attention/mamba/mixer2_rms_norm_gated.py +120 -0
  38. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +0 -50
  39. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +0 -60
  40. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +0 -111
  41. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +0 -11
  42. sglang/srt/layers/attention/triton_backend.py +1 -1
  43. sglang/srt/layers/logits_processor.py +136 -6
  44. sglang/srt/layers/modelopt_utils.py +11 -0
  45. sglang/srt/layers/moe/cutlass_w4a8_moe.py +18 -21
  46. sglang/srt/layers/moe/ep_moe/kernels.py +31 -452
  47. sglang/srt/layers/moe/ep_moe/layer.py +8 -286
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11
  49. sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0
  50. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  51. sglang/srt/layers/moe/utils.py +7 -1
  52. sglang/srt/layers/quantization/__init__.py +1 -1
  53. sglang/srt/layers/quantization/fp8.py +84 -18
  54. sglang/srt/layers/quantization/modelopt_quant.py +1 -1
  55. sglang/srt/layers/quantization/quark/quark.py +3 -1
  56. sglang/srt/layers/quantization/w4afp8.py +2 -16
  57. sglang/srt/lora/lora_manager.py +0 -8
  58. sglang/srt/managers/overlap_utils.py +18 -16
  59. sglang/srt/managers/schedule_batch.py +119 -90
  60. sglang/srt/managers/schedule_policy.py +1 -1
  61. sglang/srt/managers/scheduler.py +213 -126
  62. sglang/srt/managers/scheduler_metrics_mixin.py +1 -1
  63. sglang/srt/managers/scheduler_output_processor_mixin.py +180 -86
  64. sglang/srt/managers/tokenizer_manager.py +270 -53
  65. sglang/srt/managers/tp_worker.py +39 -28
  66. sglang/srt/mem_cache/allocator.py +7 -2
  67. sglang/srt/mem_cache/chunk_cache.py +1 -1
  68. sglang/srt/mem_cache/memory_pool.py +162 -68
  69. sglang/srt/mem_cache/radix_cache.py +8 -3
  70. sglang/srt/mem_cache/swa_radix_cache.py +70 -14
  71. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  72. sglang/srt/model_executor/forward_batch_info.py +4 -18
  73. sglang/srt/model_executor/model_runner.py +55 -51
  74. sglang/srt/model_loader/__init__.py +1 -1
  75. sglang/srt/model_loader/loader.py +187 -6
  76. sglang/srt/model_loader/weight_utils.py +3 -0
  77. sglang/srt/models/falcon_h1.py +11 -9
  78. sglang/srt/models/gemma3_mm.py +16 -0
  79. sglang/srt/models/grok.py +5 -13
  80. sglang/srt/models/mixtral.py +1 -3
  81. sglang/srt/models/mllama4.py +11 -1
  82. sglang/srt/models/nemotron_h.py +514 -0
  83. sglang/srt/models/utils.py +5 -1
  84. sglang/srt/sampling/sampling_batch_info.py +11 -9
  85. sglang/srt/server_args.py +100 -33
  86. sglang/srt/speculative/eagle_worker.py +11 -13
  87. sglang/srt/speculative/ngram_worker.py +12 -11
  88. sglang/srt/speculative/spec_utils.py +0 -1
  89. sglang/srt/two_batch_overlap.py +1 -0
  90. sglang/srt/utils/common.py +18 -0
  91. sglang/srt/utils/hf_transformers_utils.py +2 -0
  92. sglang/test/longbench_v2/__init__.py +1 -0
  93. sglang/test/longbench_v2/test_longbench_v2_eval.py +238 -0
  94. sglang/test/longbench_v2/validate_longbench_v2.py +337 -0
  95. sglang/test/longbench_v2/validate_longbench_v2_standalone.py +306 -0
  96. sglang/test/run_eval.py +40 -0
  97. sglang/test/simple_eval_longbench_v2.py +332 -0
  98. sglang/test/test_cutlass_w4a8_moe.py +9 -19
  99. sglang/test/test_deterministic.py +18 -2
  100. sglang/test/test_deterministic_utils.py +81 -0
  101. sglang/test/test_disaggregation_utils.py +63 -0
  102. sglang/test/test_utils.py +32 -11
  103. sglang/version.py +1 -1
  104. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/METADATA +4 -4
  105. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/RECORD +109 -98
  106. sglang/srt/layers/attention/mamba/mamba_utils.py +0 -81
  107. sglang/srt/managers/tp_worker_overlap_thread.py +0 -311
  108. sglang/test/test_block_fp8_ep.py +0 -358
  109. /sglang/srt/speculative/{ngram_utils.py → ngram_info.py} +0 -0
  110. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/WHEEL +0 -0
  111. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/licenses/LICENSE +0 -0
  112. {sglang-0.5.3.dist-info → sglang-0.5.3.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py +8 -286

@@ -1,14 +1,10 @@
 from __future__ import annotations

 import logging
-from contextlib import nullcontext
-from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union

 import torch
-import triton
-import triton.language as tl

-from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe import (
     get_deepep_mode,
     get_moe_a2a_backend,
@@ -18,13 +14,10 @@ from sglang.srt.layers.moe import (
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    moe_ep_deepgemm_preprocess,
-    post_reorder_triton_kernel,
     silu_and_mul_masked_post_quant_fwd,
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import FlashInferFusedMoE, FusedMoE
-from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import Fp8Config
@@ -36,19 +29,10 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     CUTEDSL_MOE_NVFP4_DISPATCH,
     ModelOptNvFp4FusedMoEMethod,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.offloader import get_offloader
 from sglang.srt.single_batch_overlap import DownGemmOverlapArgs
-from sglang.srt.utils import (
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    get_int_env_var,
-    is_cuda,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu

 if TYPE_CHECKING:
     from sglang.srt.layers.moe.token_dispatcher import (
@@ -72,29 +56,13 @@ if _use_aiter:
 logger = logging.getLogger(__name__)


-# TODO(kaixih@nvidia): ideally we should merge this logic into
-# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
-@torch.compile
-def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
-    temp = x.to(torch.float32).view(torch.int32)
-    exp = torch.bitwise_right_shift(temp, 23)
-    mant = torch.bitwise_and(temp, 0x7FFFFF)
-    is_ru = torch.logical_and(
-        torch.logical_and((mant > 0), (exp != 0xFE)),
-        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
-    )
-    exp = torch.where(is_ru, exp + 1, exp)
-    new_x = exp.to(torch.uint8).view(torch.int)
-    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
-
-
-class EPMoE(FusedMoE):
+class DeepEPMoE(FusedMoE):
     """
-    MoE Expert Parallel Impl
-
-
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
     """

+    _has_printed = False
+
     def __init__(
         self,
         num_experts: int,
@@ -108,272 +76,29 @@ class EPMoE(FusedMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        gemm1_alpha: Optional[float] = None,
-        gemm1_clamp_limit: Optional[float] = None,
-        with_bias: bool = False,
     ):
         super().__init__(
             num_experts=num_experts,
+            top_k=top_k,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
-            top_k=top_k,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             prefix=prefix,
             activation=activation,
-            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            gemm1_alpha=gemm1_alpha,
-            gemm1_clamp_limit=gemm1_clamp_limit,
-            with_bias=with_bias,
         )

-        self.intermediate_size = intermediate_size
-
         if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
-            self.block_shape = (
-                self.quant_method.quant_config.weight_block_size
-                if self.use_block_quant
-                else None
-            )
             self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
-            self.activation_scheme = quant_config.activation_scheme
         else:
             self.use_fp8_w8a8 = False
             self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-
-    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, topk_output)
-        else:
-            return super().forward(hidden_states, topk_output)
-
-    def forward_deepgemm(
-        self,
-        hidden_states: torch.Tensor,
-        topk_output: TopKOutput,
-    ):
-
-        self.w13_weight_fp8 = (
-            self.w13_weight,
-            (
-                self.w13_weight_scale_inv
-                if self.use_block_quant
-                else self.w13_weight_scale
-            ),
-        )
-        self.w2_weight_fp8 = (
-            self.w2_weight,
-            self.w2_weight_scale_inv if self.use_block_quant else self.w2_weight_scale,
-        )
-
-        assert self.quant_method is not None
-        assert self.moe_runner_config.activation == "silu"
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        topk_weights, topk_ids, _ = topk_output
-
-        if not self.use_block_quant:
-            # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
-            scale_block_size = 128
-            w13_weight_scale_n = 2 * (
-                (self.intermediate_size + scale_block_size - 1) // scale_block_size
-            )
-            w13_weight_scale_k = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w13_weight_scale = (
-                self.w13_weight_scale.unsqueeze(1)
-                .repeat_interleave(w13_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w13_weight_scale_k, dim=2)
-            )
-            self.w13_weight_fp8 = (
-                self.w13_weight,
-                w13_weight_scale,
-            )
-            w2_weight_scale_n = (
-                hidden_states_shape[-1] + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale_k = (
-                self.intermediate_size + scale_block_size - 1
-            ) // scale_block_size
-            w2_weight_scale = (
-                self.w2_weight_scale.unsqueeze(1)
-                .repeat_interleave(w2_weight_scale_n, dim=1)
-                .unsqueeze(2)
-                .repeat_interleave(w2_weight_scale_k, dim=2)
-            )
-            self.w2_weight_fp8 = (
-                self.w2_weight,
-                w2_weight_scale,
-            )
-
-        # PreReorder
-        m_max, masked_m, expected_m, src2dst, gateup_input, gateup_input_scale = (
-            moe_ep_deepgemm_preprocess(
-                topk_ids,
-                self.num_experts,
-                hidden_states,
-                self.top_k,
-                self.start_expert_id,
-                self.end_expert_id,
-                self.block_shape,
-            )
-        )
-
-        dispose_tensor(hidden_states)
-
-        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
-            b, s_mn, s_k = gateup_input_scale.shape
-            assert (
-                s_mn % 4 == 0 and s_k % 4 == 0
-            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
-
-        # GroupGemm-0
-        gateup_input_fp8 = (
-            gateup_input,
-            (
-                _cast_to_e8m0_with_rounding_up(gateup_input_scale)
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
-                    gateup_input_scale
-                )
-            ),
-        )
-        num_groups, m, k = gateup_input_fp8[0].size()
-        n = self.w13_weight.size(1)
-        gateup_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            gateup_input_fp8,
-            self.w13_weight_fp8,
-            gateup_output,
-            masked_m,
-            expected_m,
-        )
-        del gateup_input
-        del gateup_input_fp8

-        # Act
-        down_input = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2,
-            ),
-            device=hidden_states_device,
-            dtype=self.fp8_dtype,
-        )
-        scale_block_size = 128
-        down_input_scale = torch.empty(
-            (
-                gateup_output.shape[0],
-                gateup_output.shape[1],
-                gateup_output.shape[2] // 2 // scale_block_size,
-            ),
-            device=hidden_states_device,
-            dtype=torch.float32,
-        )
-        silu_and_mul_masked_post_quant_fwd(
-            gateup_output,
-            down_input,
-            down_input_scale,
-            scale_block_size,
-            masked_m,
-            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
-        )
-        del gateup_output
-
-        # GroupGemm-1
-        n = self.w2_weight.size(1)
-        down_input_fp8 = (
-            down_input,
-            (
-                down_input_scale
-                if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0
-                else deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(down_input_scale)
-            ),
-        )
-        down_output = torch.empty(
-            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
-        )
-        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
-            down_input_fp8,
-            self.w2_weight_fp8,
-            down_output,
-            masked_m,
-            expected_m,
-        )
-        del down_input
-        del down_input_fp8
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            m_max * self.start_expert_id,
-            BLOCK_SIZE=512,
-        )
-        if self.moe_runner_config.routed_scaling_factor is not None:
-            output *= self.moe_runner_config.routed_scaling_factor
-        return output
-
-
-class DeepEPMoE(EPMoE):
-    """
-    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
-    """
-
-    _has_printed = False
-
-    def __init__(
-        self,
-        num_experts: int,
-        top_k: int,
-        hidden_size: int,
-        intermediate_size: int,
-        layer_id: int,
-        num_fused_shared_experts: int = 0,
-        params_dtype: Optional[torch.dtype] = None,
-        quant_config: Optional[QuantizationConfig] = None,
-        prefix: str = "",
-        activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
-    ):
-        super().__init__(
-            num_experts=num_experts,
-            top_k=top_k,
-            hidden_size=hidden_size,
-            intermediate_size=intermediate_size,
-            layer_id=layer_id,
-            num_fused_shared_experts=num_fused_shared_experts,
-            params_dtype=params_dtype,
-            quant_config=quant_config,
-            prefix=prefix,
-            activation=activation,
-            routed_scaling_factor=routed_scaling_factor,
-        )
         self.deepep_mode = get_deepep_mode()

         # TODO: move to the beginning of the file
@@ -567,7 +292,6 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128

-        # TODO also unify other branches (e.g. `EPMoE.forward_deepgemm` sets the field on forward pass)
         w13_weight_fp8 = (
             self.w13_weight,
             (
@@ -988,8 +712,6 @@ def get_moe_impl_class(quant_config: Optional[QuantizationConfig]):
         return FlashInferFusedMoE
     if get_moe_runner_backend().is_flashinfer_cutlass():
         return FusedMoE
-    if get_moe_expert_parallel_world_size() > 1:
-        return EPMoE
     return FusedMoE

sglang/srt/layers/moe/fused_moe_triton/layer.py +6 -11

@@ -156,8 +156,7 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
-        self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
+
         if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
@@ -207,15 +206,11 @@
             gemm1_clamp_limit=gemm1_clamp_limit,
         )

-        if quant_config is None:
-            self.quant_method: FusedMoEMethodBase = UnquantizedFusedMoEMethod(
-                self.use_triton_kernels
-            )
-        else:
-            self.quant_method: FusedMoEMethodBase = quant_config.get_quant_method(
-                self, prefix
-            )
-        assert self.quant_method is not None
+        self.quant_method: Optional[FusedMoEMethodBase] = None
+        if quant_config is not None:
+            self.quant_method = quant_config.get_quant_method(self, prefix)
+        if self.quant_method is None:
+            self.quant_method = UnquantizedFusedMoEMethod(self.use_triton_kernels)

         self.quant_method.create_weights(
             layer=self,
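
The second hunk above changes how FusedMoE picks its quantization method: quant_config.get_quant_method is now allowed to return None, and the layer then falls back to UnquantizedFusedMoEMethod instead of asserting. Below is a minimal sketch of that fallback pattern; the classes are invented toy stand-ins, not the real sglang types.

# Illustrative sketch only: "ask the quant config first, fall back to unquantized".
from typing import Optional


class UnquantizedMethod:
    def __str__(self) -> str:
        return "unquantized"


class ToyFp8Config:
    def get_quant_method(self, layer: object, prefix: str) -> Optional[str]:
        # Pretend this config only handles layers under "model.layers".
        return "fp8-moe" if prefix.startswith("model.layers") else None


def select_quant_method(quant_config: Optional[ToyFp8Config], layer: object, prefix: str):
    quant_method = None
    if quant_config is not None:
        quant_method = quant_config.get_quant_method(layer, prefix)
    if quant_method is None:
        # Fallback instead of an assertion failure.
        quant_method = UnquantizedMethod()
    return quant_method


print(select_quant_method(None, object(), "model.layers.0.mlp"))            # unquantized
print(select_quant_method(ToyFp8Config(), object(), "model.layers.0.mlp"))  # fp8-moe
print(select_quant_method(ToyFp8Config(), object(), "lm_head"))             # unquantized
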
sglang/srt/layers/moe/moe_runner/deep_gemm.py +304 -0

@@ -0,0 +1,304 @@
+from __future__ import annotations
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, List, Optional
+
+import torch
+
+from sglang.srt.layers.moe.moe_runner.base import (
+    MoeQuantInfo,
+    MoeRunnerConfig,
+    MoeRunnerCore,
+    RunnerInput,
+    RunnerOutput,
+    register_post_permute,
+    register_pre_permute,
+)
+from sglang.srt.layers.moe.utils import MoeRunnerBackend
+from sglang.srt.utils import dispose_tensor
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.token_dispatcher.standard import (
+        StandardCombineInput,
+        StandardDispatchOutput,
+    )
+
+
+# TODO(kaixih@nvidia): ideally we should merge this logic into
+# `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
+@torch.compile
+def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
+    temp = x.to(torch.float32).view(torch.int32)
+    exp = torch.bitwise_right_shift(temp, 23)
+    mant = torch.bitwise_and(temp, 0x7FFFFF)
+    is_ru = torch.logical_and(
+        torch.logical_and((mant > 0), (exp != 0xFE)),
+        ~torch.logical_and((exp == 0), (mant <= 0x400000)),
+    )
+    exp = torch.where(is_ru, exp + 1, exp)
+    new_x = exp.to(torch.uint8).view(torch.int)
+    return new_x.transpose(1, 2).contiguous().transpose(1, 2)
+
+
+@dataclass
+class DeepGemmRunnerInput(RunnerInput):
+    hidden_states: torch.Tensor
+    hidden_states_scale: torch.Tensor
+    masked_m: torch.Tensor
+    expected_m: int
+    use_masked_gemm: bool
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmRunnerOutput(RunnerOutput):
+    hidden_states: torch.Tensor
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@dataclass
+class DeepGemmMoeQuantInfo(MoeQuantInfo):
+    w13_weight: torch.Tensor
+    w2_weight: torch.Tensor
+    use_fp8: bool
+    w13_scale: Optional[torch.Tensor] = None
+    w2_scale: Optional[torch.Tensor] = None
+    block_shape: Optional[List[int]] = None
+
+
+class DeepGemmRunnerCore(MoeRunnerCore):
+    def __init__(self, config: MoeRunnerConfig):
+        super().__init__(config)
+        assert self.config.activation == "silu"
+
+    def run(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> DeepGemmRunnerOutput:
+
+        if runner_input.use_masked_gemm:
+            hidden_states = self._run_masked_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        else:
+            hidden_states = self._run_contiguous_gemm(
+                runner_input,
+                quant_info,
+                running_state,
+            )
+        return DeepGemmRunnerOutput(hidden_states=hidden_states)
+
+    def _run_masked_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+
+        from sglang.srt.layers.moe.ep_moe.kernels import (
+            silu_and_mul_masked_post_quant_fwd,
+        )
+        from sglang.srt.layers.quantization import deep_gemm_wrapper
+
+        hidden_states = runner_input.hidden_states
+        hidden_states_scale = runner_input.hidden_states_scale
+        masked_m = runner_input.masked_m
+        expected_m = runner_input.expected_m
+
+        w13_weight = quant_info.w13_weight
+        w2_weight = quant_info.w2_weight
+        w13_scale = quant_info.w13_scale
+        w2_scale = quant_info.w2_scale
+
+        hidden_states_device = running_state["hidden_states_device"]
+
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            b, s_mn, s_k = hidden_states_scale.shape
+            assert (
+                s_mn % 4 == 0 and s_k % 4 == 0
+            ), f"scales must be aligned to 4, but got ({b}, {s_mn}, {s_k})"
+
+        # GroupGemm-0
+        if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            hidden_states_scale = _cast_to_e8m0_with_rounding_up(hidden_states_scale)
+        else:
+            hidden_states_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                hidden_states_scale
+            )
+
+        num_groups, m, k = hidden_states.shape
+        n = w13_weight.size(1)
+        gateup_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (hidden_states, hidden_states_scale),
+            (w13_weight, w13_scale),
+            gateup_output,
+            masked_m,
+            expected_m,
+        )
+        dispose_tensor(hidden_states)
+
+        # Act
+        down_input = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float8_e4m3fn,
+        )
+        scale_block_size = 128
+        down_input_scale = torch.empty(
+            (
+                gateup_output.shape[0],
+                gateup_output.shape[1],
+                gateup_output.shape[2] // 2 // scale_block_size,
+            ),
+            device=hidden_states_device,
+            dtype=torch.float32,
+        )
+        silu_and_mul_masked_post_quant_fwd(
+            gateup_output,
+            down_input,
+            down_input_scale,
+            scale_block_size,
+            masked_m,
+            scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+        )
+        del gateup_output
+
+        # GroupGemm-1
+        n = w2_weight.shape[1]
+
+        if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+            down_input_scale = deep_gemm_wrapper.get_mn_major_tma_aligned_tensor(
+                down_input_scale
+            )
+
+        down_output = torch.empty(
+            (num_groups, m, n), device=hidden_states_device, dtype=torch.bfloat16
+        )
+        deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_masked(
+            (down_input, down_input_scale),
+            (w2_weight, w2_scale),
+            down_output,
+            masked_m,
+            expected_m,
+        )
+        del down_input
+
+        return down_output
+
+    def _run_contiguous_gemm(
+        self,
+        runner_input: DeepGemmRunnerInput,
+        quant_info: DeepGemmMoeQuantInfo,
+        running_state: dict,
+    ) -> torch.Tensor:
+        pass
+
+    @property
+    def runner_backend(self) -> MoeRunnerBackend:
+        return MoeRunnerBackend.DEEP_GEMM
+
+
+@register_pre_permute("standard", "deep_gemm")
+def pre_permute_standard_to_deep_gemm(
+    dispatch_output: StandardDispatchOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> DeepGemmRunnerInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess
+
+    hidden_states, topk_output = dispatch_output
+    topk_weights, topk_ids, _ = topk_output
+
+    hidden_states_shape = hidden_states.shape
+    hidden_states_dtype = hidden_states.dtype
+    hidden_states_device = hidden_states.device
+    hidden_states_ref = hidden_states
+
+    topk_weights, topk_ids = topk_weights, topk_ids
+
+    # PreReorder
+    masked_m, expected_m, src2dst, hidden_states, hidden_states_scale = (
+        moe_ep_deepgemm_preprocess(
+            topk_ids,
+            runner_config.num_local_experts,
+            hidden_states,
+            runner_config.top_k,
+            quant_info.block_shape,
+        )
+    )
+
+    dispose_tensor(hidden_states_ref)
+
+    running_state["topk_ids"] = topk_ids
+    running_state["topk_weights"] = topk_weights
+    running_state["hidden_states_shape"] = hidden_states_shape
+    running_state["hidden_states_dtype"] = hidden_states_dtype
+    running_state["hidden_states_device"] = hidden_states_device
+    running_state["src2dst"] = src2dst
+
+    return DeepGemmRunnerInput(
+        hidden_states=hidden_states,
+        hidden_states_scale=hidden_states_scale,
+        masked_m=masked_m,
+        expected_m=expected_m,
+        use_masked_gemm=True,
+    )
+
+
+@register_post_permute("deep_gemm", "standard")
+def post_permute_deep_gemm_to_standard(
+    runner_output: DeepGemmRunnerOutput,
+    quant_info: DeepGemmMoeQuantInfo,
+    runner_config: MoeRunnerConfig,
+    running_state: dict,
+) -> StandardCombineInput:
+    from sglang.srt.layers.moe.ep_moe.kernels import post_reorder_triton_kernel
+    from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+    hidden_states_shape = running_state["hidden_states_shape"]
+    hidden_states_dtype = running_state["hidden_states_dtype"]
+    hidden_states_device = running_state["hidden_states_device"]
+    src2dst = running_state["src2dst"]
+    topk_ids = running_state["topk_ids"]
+    topk_weights = running_state["topk_weights"]
+
+    output = torch.empty(
+        hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+    )
+    post_reorder_triton_kernel[(hidden_states_shape[0],)](
+        runner_output.hidden_states,
+        output,
+        src2dst,
+        topk_ids,
+        topk_weights,
+        runner_config.top_k,
+        hidden_states_shape[1],
+        BLOCK_SIZE=512,
+    )
+
+    dispose_tensor(runner_output.hidden_states)
+
+    if runner_config.routed_scaling_factor is not None:
+        output *= runner_config.routed_scaling_factor
+
+    return StandardCombineInput(
+        hidden_states=output,
+    )
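
The new module follows sglang's MoE-runner convention: a pre-permute hook builds a backend-specific RunnerInput, DeepGemmRunnerCore.run executes the grouped GEMMs, and a post-permute hook scatters the result back, with tensor metadata handed along through a shared running_state dict via register_pre_permute("standard", "deep_gemm") and register_post_permute("deep_gemm", "standard"). The sketch below only mirrors that three-phase control flow with invented toy stand-ins; the real dispatch through the MoeRunner registry is assumed, not reproduced.

# Toy illustration of the pre-permute -> run -> post-permute pipeline used above.
# All classes and transforms here are simplified stand-ins; only the control
# flow and the shared running_state dict mirror the real code.
from dataclasses import dataclass

import torch


@dataclass
class ToyRunnerInput:
    hidden_states: torch.Tensor


@dataclass
class ToyRunnerOutput:
    hidden_states: torch.Tensor


def pre_permute(hidden_states: torch.Tensor, running_state: dict) -> ToyRunnerInput:
    # Record everything the post-permute step will need to restore the output.
    running_state["shape"] = hidden_states.shape
    running_state["dtype"] = hidden_states.dtype
    running_state["device"] = hidden_states.device
    # Stand-in for the per-expert token reordering (moe_ep_deepgemm_preprocess).
    return ToyRunnerInput(hidden_states=hidden_states.flip(0))


def run(runner_input: ToyRunnerInput, running_state: dict) -> ToyRunnerOutput:
    # Stand-in for the two grouped GEMMs plus the SiLU-and-mul activation.
    return ToyRunnerOutput(hidden_states=runner_input.hidden_states * 2)


def post_permute(runner_output: ToyRunnerOutput, running_state: dict) -> torch.Tensor:
    # Scatter back to the original token order, dtype, and shape.
    out = runner_output.hidden_states.flip(0).to(running_state["dtype"])
    assert out.shape == running_state["shape"]
    return out


if __name__ == "__main__":
    x = torch.randn(4, 8)
    state: dict = {}
    y = post_permute(run(pre_permute(x, state), state), state)
    print(y.shape)  # torch.Size([4, 8])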