sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries; it is provided for informational purposes only.
Files changed (168)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/configs/deepseekvl2.py +11 -2
  4. sglang/srt/configs/internvl.py +3 -0
  5. sglang/srt/configs/janus_pro.py +3 -0
  6. sglang/srt/configs/model_config.py +9 -7
  7. sglang/srt/configs/update_config.py +3 -1
  8. sglang/srt/conversation.py +1 -0
  9. sglang/srt/custom_op.py +5 -2
  10. sglang/srt/disaggregation/decode.py +9 -1
  11. sglang/srt/disaggregation/mooncake/conn.py +44 -56
  12. sglang/srt/distributed/parallel_state.py +33 -0
  13. sglang/srt/entrypoints/engine.py +30 -26
  14. sglang/srt/entrypoints/openai/serving_chat.py +21 -2
  15. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  16. sglang/srt/function_call/function_call_parser.py +2 -0
  17. sglang/srt/function_call/qwen3_detector.py +150 -0
  18. sglang/srt/hf_transformers_utils.py +0 -1
  19. sglang/srt/layers/activation.py +13 -0
  20. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  21. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  22. sglang/srt/layers/linear.py +13 -102
  23. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  24. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  25. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  26. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  30. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  31. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  34. sglang/srt/layers/moe/topk.py +187 -12
  35. sglang/srt/layers/quantization/__init__.py +20 -134
  36. sglang/srt/layers/quantization/awq.py +578 -11
  37. sglang/srt/layers/quantization/awq_triton.py +339 -0
  38. sglang/srt/layers/quantization/base_config.py +85 -10
  39. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
  42. sglang/srt/layers/quantization/fp8.py +273 -62
  43. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  44. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  45. sglang/srt/layers/quantization/gptq.py +501 -143
  46. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  47. sglang/srt/layers/quantization/modelopt_quant.py +26 -108
  48. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  49. sglang/srt/layers/quantization/petit.py +252 -0
  50. sglang/srt/layers/quantization/petit_utils.py +104 -0
  51. sglang/srt/layers/quantization/qoq.py +7 -6
  52. sglang/srt/layers/quantization/scalar_type.py +352 -0
  53. sglang/srt/layers/quantization/unquant.py +422 -0
  54. sglang/srt/layers/quantization/utils.py +343 -3
  55. sglang/srt/layers/quantization/w4afp8.py +8 -4
  56. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  57. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  58. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  59. sglang/srt/lora/lora.py +0 -4
  60. sglang/srt/lora/lora_manager.py +87 -53
  61. sglang/srt/lora/mem_pool.py +81 -33
  62. sglang/srt/lora/utils.py +12 -5
  63. sglang/srt/managers/cache_controller.py +241 -0
  64. sglang/srt/managers/io_struct.py +41 -29
  65. sglang/srt/managers/mm_utils.py +7 -8
  66. sglang/srt/managers/schedule_batch.py +150 -110
  67. sglang/srt/managers/schedule_policy.py +68 -27
  68. sglang/srt/managers/scheduler.py +243 -61
  69. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  70. sglang/srt/managers/tokenizer_manager.py +11 -3
  71. sglang/srt/managers/tp_worker.py +14 -0
  72. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  73. sglang/srt/mem_cache/allocator.py +7 -16
  74. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  75. sglang/srt/mem_cache/chunk_cache.py +5 -2
  76. sglang/srt/mem_cache/hicache_storage.py +152 -0
  77. sglang/srt/mem_cache/hiradix_cache.py +179 -4
  78. sglang/srt/mem_cache/memory_pool.py +16 -1
  79. sglang/srt/mem_cache/memory_pool_host.py +41 -2
  80. sglang/srt/mem_cache/radix_cache.py +26 -0
  81. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  82. sglang/srt/metrics/collector.py +9 -0
  83. sglang/srt/model_executor/cuda_graph_runner.py +5 -6
  84. sglang/srt/model_executor/forward_batch_info.py +14 -1
  85. sglang/srt/model_executor/model_runner.py +109 -22
  86. sglang/srt/model_loader/loader.py +7 -1
  87. sglang/srt/model_loader/utils.py +4 -4
  88. sglang/srt/models/clip.py +1 -1
  89. sglang/srt/models/deepseek.py +9 -6
  90. sglang/srt/models/deepseek_janus_pro.py +1 -1
  91. sglang/srt/models/deepseek_v2.py +191 -171
  92. sglang/srt/models/deepseek_vl2.py +5 -5
  93. sglang/srt/models/gemma.py +48 -0
  94. sglang/srt/models/gemma2.py +52 -0
  95. sglang/srt/models/gemma3_causal.py +63 -0
  96. sglang/srt/models/gemma3_mm.py +1 -1
  97. sglang/srt/models/gemma3n_mm.py +2 -4
  98. sglang/srt/models/granitemoe.py +385 -0
  99. sglang/srt/models/grok.py +9 -3
  100. sglang/srt/models/hunyuan.py +63 -16
  101. sglang/srt/models/internvl.py +1 -1
  102. sglang/srt/models/kimi_vl.py +1 -1
  103. sglang/srt/models/llama.py +41 -0
  104. sglang/srt/models/llama4.py +11 -11
  105. sglang/srt/models/llava.py +2 -2
  106. sglang/srt/models/llavavid.py +1 -1
  107. sglang/srt/models/minicpm.py +0 -2
  108. sglang/srt/models/minicpmo.py +3 -7
  109. sglang/srt/models/minicpmv.py +1 -1
  110. sglang/srt/models/mistral.py +1 -1
  111. sglang/srt/models/mixtral.py +9 -2
  112. sglang/srt/models/mllama.py +3 -5
  113. sglang/srt/models/mllama4.py +3 -3
  114. sglang/srt/models/olmoe.py +8 -5
  115. sglang/srt/models/persimmon.py +330 -0
  116. sglang/srt/models/phi.py +321 -0
  117. sglang/srt/models/phi4mm.py +44 -4
  118. sglang/srt/models/phi4mm_audio.py +1260 -0
  119. sglang/srt/models/phi4mm_utils.py +1917 -0
  120. sglang/srt/models/phimoe.py +9 -3
  121. sglang/srt/models/qwen.py +37 -0
  122. sglang/srt/models/qwen2.py +41 -0
  123. sglang/srt/models/qwen2_5_vl.py +4 -4
  124. sglang/srt/models/qwen2_audio.py +1 -1
  125. sglang/srt/models/qwen2_moe.py +53 -5
  126. sglang/srt/models/qwen2_vl.py +4 -4
  127. sglang/srt/models/qwen3.py +65 -1
  128. sglang/srt/models/qwen3_moe.py +56 -18
  129. sglang/srt/models/vila.py +1 -1
  130. sglang/srt/multimodal/processors/base_processor.py +91 -97
  131. sglang/srt/multimodal/processors/clip.py +21 -19
  132. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  133. sglang/srt/multimodal/processors/gemma3.py +13 -17
  134. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  135. sglang/srt/multimodal/processors/internvl.py +9 -10
  136. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  137. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  138. sglang/srt/multimodal/processors/llava.py +4 -2
  139. sglang/srt/multimodal/processors/minicpm.py +35 -44
  140. sglang/srt/multimodal/processors/mlama.py +21 -18
  141. sglang/srt/multimodal/processors/mllama4.py +4 -5
  142. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  143. sglang/srt/multimodal/processors/pixtral.py +14 -35
  144. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  145. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  146. sglang/srt/multimodal/processors/vila.py +14 -14
  147. sglang/srt/sampling/sampling_params.py +8 -1
  148. sglang/srt/server_args.py +393 -230
  149. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
  150. sglang/srt/two_batch_overlap.py +1 -0
  151. sglang/srt/utils.py +27 -1
  152. sglang/test/runners.py +14 -3
  153. sglang/test/test_block_fp8.py +8 -3
  154. sglang/test/test_block_fp8_ep.py +1 -1
  155. sglang/test/test_custom_ops.py +12 -7
  156. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  157. sglang/test/test_fp4_moe.py +1 -3
  158. sglang/test/test_marlin_moe.py +286 -0
  159. sglang/test/test_marlin_utils.py +171 -0
  160. sglang/test/test_utils.py +35 -0
  161. sglang/version.py +1 -1
  162. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
  163. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
  164. sglang/srt/layers/quantization/quant_utils.py +0 -166
  165. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  166. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
  167. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
  168. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,17 +1,13 @@
  import logging
- from typing import Callable, List, Optional, Tuple
+ from typing import List, Optional, Tuple

- import einops
  import torch
- from torch.nn import Module

- from sglang.srt.custom_op import CustomOp
  from sglang.srt.distributed import (
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
  )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
- from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
  from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_gather,
  ep_scatter,
@@ -27,22 +23,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
  silu_and_mul_triton_kernel,
  tma_align_input_scale,
  )
- from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
- from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+ from sglang.srt.layers.moe.topk import TopKOutput
  from sglang.srt.layers.quantization import deep_gemm_wrapper
  from sglang.srt.layers.quantization.base_config import (
  QuantizationConfig,
  QuantizeMethodBase,
  )
- from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+ from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
  from sglang.srt.layers.quantization.fp8_kernel import (
  is_fp8_fnuz,
- scaled_fp8_quant,
  sglang_per_token_group_quant_fp8,
  sglang_per_token_quant_fp8,
  )
- from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+ from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
  from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -53,7 +47,6 @@ from sglang.srt.utils import (
  get_bool_env_var,
  is_hip,
  is_npu,
- set_weight_attrs,
  )

  _is_hip = is_hip()
@@ -61,14 +54,11 @@ _is_npu = is_npu()
  _is_fp8_fnuz = is_fp8_fnuz()
  _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip

- if not _is_npu:
+ if not (_is_npu or _is_hip):
  from sgl_kernel import silu_and_mul

  from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe

- if _is_hip:
- from vllm._custom_ops import scaled_fp8_quant
-
  if _use_aiter:
  from aiter import ActivationType, QuantType
  from aiter.fused_moe import fused_moe
@@ -165,16 +155,9 @@ class EPMoE(torch.nn.Module):
  intermediate_size: int,
  layer_id: int,
  params_dtype: Optional[torch.dtype] = None,
- renormalize: bool = True,
- use_grouped_topk: bool = False,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: int = 0,
- topk_group: Optional[int] = None,
  quant_config: Optional[QuantizationConfig] = None,
  tp_size: Optional[int] = None,
  prefix: str = "",
- correction_bias: Optional[torch.Tensor] = None,
- custom_routing_function: Optional[Callable] = None,
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
  use_per_token_if_dynamic: bool = True,
@@ -192,24 +175,12 @@ class EPMoE(torch.nn.Module):
  self.layer_id = layer_id
  self.num_experts = num_experts
  assert self.num_experts % self.tp_size == 0
- assert (
- num_fused_shared_experts == 0
- ), "num_fused_shared_experts is not supported in EP"
- self.num_fused_shared_experts = num_fused_shared_experts
  self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
  self.start_expert_id = self.tp_rank * self.num_experts_per_partition
  self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1

  self.top_k = top_k
  self.intermediate_size = intermediate_size
- self.renormalize = renormalize
- self.use_grouped_topk = use_grouped_topk
- if self.use_grouped_topk:
- assert num_expert_group is not None and topk_group is not None
- self.num_expert_group = num_expert_group
- self.topk_group = topk_group
- self.correction_bias = correction_bias
- self.custom_routing_function = custom_routing_function
  self.activation = activation
  self.routed_scaling_factor = routed_scaling_factor
  self.use_per_token_if_dynamic = use_per_token_if_dynamic
@@ -314,33 +285,24 @@ class EPMoE(torch.nn.Module):
  )
  return (local_num_experts, expert_map)

- def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
  if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
- return self.forward_deepgemm(hidden_states, router_logits)
+ return self.forward_deepgemm(hidden_states, topk_output)
  else:
- return self.forward_normal(hidden_states, router_logits)
+ return self.forward_normal(hidden_states, topk_output)

  def forward_deepgemm(
- self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+ self,
+ hidden_states: torch.Tensor,
+ topk_output: TopKOutput,
  ):
  assert self.quant_method is not None
  assert self.activation == "silu"
  hidden_states_shape = hidden_states.shape
  hidden_states_dtype = hidden_states.dtype
  hidden_states_device = hidden_states.device
- topk_weights, topk_ids = select_experts(
- hidden_states=hidden_states,
- router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=self.use_grouped_topk,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- custom_routing_function=self.custom_routing_function,
- routed_scaling_factor=self.routed_scaling_factor,
- )
+
+ topk_weights, topk_ids, _ = topk_output

  if not self.use_block_quant:
  # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
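The hunk above changes EPMoE's public contract: routing is no longer performed inside forward, which now receives a precomputed TopKOutput instead of raw router_logits. The sketch below illustrates the new calling convention; FakeTopKOutput and route are stand-ins for illustration only (the real object comes from sglang.srt.layers.moe.topk, whose constructor is not part of this hunk).

# Hedged sketch of the new EPMoE.forward calling convention shown in this diff.
# The TopKOutput construction here is a stand-in; only the 3-tuple unpacking
# (`topk_weights, topk_ids, _ = topk_output`) is taken from the hunk itself.
from typing import NamedTuple

import torch


class FakeTopKOutput(NamedTuple):
    topk_weights: torch.Tensor  # [tokens, top_k]
    topk_ids: torch.Tensor      # [tokens, top_k]
    router_logits: torch.Tensor  # third field is ignored by EPMoE above


def route(router_logits: torch.Tensor, top_k: int) -> FakeTopKOutput:
    """Toy softmax + top-k router standing in for sglang's topk module."""
    probs = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, top_k, dim=-1)
    topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    return FakeTopKOutput(topk_weights, topk_ids, router_logits)


# Old call (0.4.9.post2): layer.forward(hidden_states, router_logits)
# New call (0.4.9.post3): routing happens outside the MoE layer, e.g.
#   topk_output = route(router_logits, top_k=layer.top_k)
#   output = layer.forward(hidden_states, topk_output)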
@@ -472,8 +434,10 @@ class EPMoE(torch.nn.Module):
  )
  return output

- def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+ def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
  assert self.quant_method is not None
+ topk_weights, topk_ids, _ = topk_output
+
  hidden_states_shape = hidden_states.shape
  hidden_states_dtype = hidden_states.dtype
  hidden_states_device = hidden_states.device
@@ -484,23 +448,6 @@ class EPMoE(torch.nn.Module):
  use_per_token_if_dynamic=self.use_per_token_if_dynamic,
  )

- topk_weights, topk_ids = select_experts(
- hidden_states=hidden_states,
- router_logits=router_logits,
- top_k=self.top_k,
- use_grouped_topk=self.use_grouped_topk,
- renormalize=self.renormalize,
- topk_group=self.topk_group,
- num_expert_group=self.num_expert_group,
- num_fused_shared_experts=self.num_fused_shared_experts,
- correction_bias=self.correction_bias,
- custom_routing_function=self.custom_routing_function,
- routed_scaling_factor=self.routed_scaling_factor,
- expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
- layer_id=self.layer_id,
- ),
- )
-
  if self.use_w4afp8:
  local_topk_ids = topk_ids
  if self.expert_map is not None:
@@ -904,324 +851,6 @@ class EPMoE(torch.nn.Module):
  param_data[expert_id] = loaded_weight


- class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
- def create_weights(
- self,
- layer: torch.nn.Module,
- num_experts_per_partition: int,
- hidden_size: int,
- intermediate_size: int,
- params_dtype: torch.dtype,
- **extra_weight_attrs,
- ):
- # Fused gate_up_proj (column parallel)
- w13_weight = torch.nn.Parameter(
- torch.empty(
- num_experts_per_partition,
- 2 * intermediate_size,
- hidden_size,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
- layer.register_parameter("w13_weight", w13_weight)
- set_weight_attrs(w13_weight, extra_weight_attrs)
-
- # down_proj (row parallel)
- w2_weight = torch.nn.Parameter(
- torch.empty(
- num_experts_per_partition,
- hidden_size,
- intermediate_size,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
- layer.register_parameter("w2_weight", w2_weight)
- set_weight_attrs(w2_weight, extra_weight_attrs)
-
- # scale
- layer.register_parameter("w13_input_scale", None)
- layer.register_parameter("w13_weight_scale", None)
-
- ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
- w2_input_scale = torch.nn.Parameter(
- ones_tensor,
- requires_grad=False,
- )
- layer.register_parameter("w2_input_scale", w2_input_scale)
- set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
- w2_weight_scale = torch.nn.Parameter(
- ones_tensor,
- requires_grad=False,
- )
- layer.register_parameter("w2_weight_scale", w2_weight_scale)
- set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
- def apply(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- ) -> torch.Tensor:
- raise NotImplementedError
-
-
- class Fp8EPMoEMethod(Fp8MoEMethod):
- """MoE method for FP8.
- Supports loading FP8 checkpoints with static weight scale and
- dynamic/static activation scale.
-
- Args:
- quant_config: The quantization config.
- """
-
- def __init__(self, quant_config: Fp8Config):
- self.quant_config = quant_config
- self.block_quant = self.quant_config.weight_block_size is not None
-
- def create_weights(
- self,
- layer: Module,
- num_experts_per_partition: int,
- hidden_size: int,
- intermediate_size: int,
- params_dtype: torch.dtype,
- **extra_weight_attrs,
- ):
- if self.quant_config.is_checkpoint_fp8_serialized:
- params_dtype = torch.float8_e4m3fn
-
- tp_size = get_tensor_model_parallel_world_size()
- if self.block_quant:
- block_n, block_k = (
- self.quant_config.weight_block_size[0],
- self.quant_config.weight_block_size[1],
- )
- # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
- # Required by column parallel or enabling merged weights
- if intermediate_size % block_n != 0:
- raise ValueError(
- f"The output_size of gate's and up's weight = "
- f"{intermediate_size} is not divisible by "
- f"weight quantization block_n = {block_n}."
- )
- if tp_size > 1:
- # Required by row parallel
- if intermediate_size % block_k != 0:
- raise ValueError(
- f"The input_size of down's weight = "
- f"{intermediate_size} is not divisible by "
- f"weight quantization block_k = {block_k}."
- )
-
- # WEIGHTS
- w13_weight = torch.nn.Parameter(
- torch.empty(
- num_experts_per_partition,
- 2 * intermediate_size,
- hidden_size,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
- layer.register_parameter("w13_weight", w13_weight)
- set_weight_attrs(w13_weight, extra_weight_attrs)
-
- w2_weight = torch.nn.Parameter(
- torch.empty(
- num_experts_per_partition,
- hidden_size,
- intermediate_size,
- dtype=params_dtype,
- ),
- requires_grad=False,
- )
- layer.register_parameter("w2_weight", w2_weight)
- set_weight_attrs(w2_weight, extra_weight_attrs)
-
- # WEIGHT_SCALES
- if self.block_quant:
- w13_weight_scale = torch.nn.Parameter(
- torch.ones(
- num_experts_per_partition,
- 2 * ((intermediate_size + block_n - 1) // block_n),
- (hidden_size + block_k - 1) // block_k,
- dtype=torch.float32,
- ),
- requires_grad=False,
- )
- w2_weight_scale = torch.nn.Parameter(
- torch.ones(
- num_experts_per_partition,
- (hidden_size + block_n - 1) // block_n,
- (intermediate_size + block_k - 1) // block_k,
- dtype=torch.float32,
- ),
- requires_grad=False,
- )
- layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
- layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
- assert self.quant_config.activation_scheme == "dynamic"
- else:
- # WEIGHT_SCALES
- # Allocate 2 scales for w1 and w3 respectively.
- w13_weight_scale = torch.nn.Parameter(
- torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
- requires_grad=False,
- )
- layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
- w2_weight_scale = torch.nn.Parameter(
- torch.ones(num_experts_per_partition, dtype=torch.float32),
- requires_grad=False,
- )
- layer.register_parameter("w2_weight_scale", w2_weight_scale)
- # Add the quantization method used (per tensor/grouped/channel)
- # to ensure the weight scales are loaded in properly
- extra_weight_attrs.update(
- {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
- if self.block_quant
- else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
- )
- # If loading fp8 checkpoint, pass the weight loaders.
- # If loading an fp16 checkpoint, do not (we will quantize in
- # process_weights_after_loading()
- if self.quant_config.is_checkpoint_fp8_serialized:
- set_weight_attrs(w13_weight_scale, extra_weight_attrs)
- set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
- # INPUT_SCALES
- if self.quant_config.activation_scheme == "static":
- if not self.quant_config.is_checkpoint_fp8_serialized:
- raise ValueError(
- "Found static activation scheme for checkpoint that "
- "was not serialized fp8."
- )
-
- w13_input_scale = torch.nn.Parameter(
- torch.ones(num_experts_per_partition, dtype=torch.float32),
- requires_grad=False,
- )
- layer.register_parameter("w13_input_scale", w13_input_scale)
- set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
- w2_input_scale = torch.nn.Parameter(
- torch.ones(num_experts_per_partition, dtype=torch.float32),
- requires_grad=False,
- )
- layer.register_parameter("w2_input_scale", w2_input_scale)
- set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
- else:
- layer.w13_input_scale = None
- layer.w2_input_scale = None
-
- def process_weights_after_loading(self, layer: Module) -> None:
-
- # If checkpoint is fp16, quantize in place.
- if not self.quant_config.is_checkpoint_fp8_serialized:
- # If rocm, use float8_e4m3fnuz as dtype
- fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
- w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
- w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
- layer.w13_weight_scale = torch.nn.Parameter(
- torch.ones(
- layer.num_experts_per_partition,
- dtype=torch.float32,
- device=w13_weight.device,
- ),
- requires_grad=False,
- )
-
- for expert in range(layer.num_experts_per_partition):
- w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
- scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
- )
- w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
- scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
- )
- layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
- layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
- return
-
- # If checkpoint is fp8, we need to handle that the
- # MoE kernels require single activation scale and single weight
- # scale for w13 per expert.
- else:
- if self.quant_config.activation_scheme == "static":
- if layer.w13_input_scale is None or layer.w2_input_scale is None:
- raise ValueError(
- "QuantConfig has static quantization, but found "
- "activation scales are None."
- )
- layer.w13_weight_scale = torch.nn.Parameter(
- torch.max(layer.w13_weight_scale, dim=1).values,
- requires_grad=False,
- )
- if self.block_quant:
- # If ROCm, normalize the weights and scales to e4m3fnuz
- if _is_fp8_fnuz:
- # activation_scheme: dynamic
- w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=layer.w13_weight,
- weight_scale=layer.w13_weight_scale_inv,
- input_scale=None,
- )
- w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
- weight=layer.w2_weight,
- weight_scale=layer.w2_weight_scale_inv,
- input_scale=None,
- )
- # Reset the parameter
- layer.w13_weight = torch.nn.Parameter(
- w13_weight, requires_grad=False
- )
- layer.w13_weight_scale_inv = torch.nn.Parameter(
- w13_weight_scale, requires_grad=False
- )
- layer.w13_input_scale = None
- layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
- layer.w2_weight_scale_inv = torch.nn.Parameter(
- w2_weight_scale, requires_grad=False
- )
- layer.w2_input_scale = None
- if _use_aiter:
- layer.w13_weight = torch.nn.Parameter(
- shuffle_weight(layer.w13_weight.data, (16, 16)),
- requires_grad=False,
- )
- layer.w2_weight = torch.nn.Parameter(
- shuffle_weight(layer.w2_weight.data, (16, 16)),
- requires_grad=False,
- )
- return
-
- def apply(
- self,
- layer: torch.nn.Module,
- x: torch.Tensor,
- router_logits: torch.Tensor,
- top_k: int,
- renormalize: bool,
- use_grouped_topk: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- custom_routing_function: Optional[Callable] = None,
- ) -> torch.Tensor:
- raise NotImplementedError
-
-
  class DeepEPMoE(EPMoE):
  """
  MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
@@ -1237,16 +866,9 @@ class DeepEPMoE(EPMoE):
  intermediate_size: int,
  layer_id: int,
  params_dtype: Optional[torch.dtype] = None,
- renormalize: bool = True,
- use_grouped_topk: bool = False,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: int = 0,
- topk_group: Optional[int] = None,
  quant_config: Optional[QuantizationConfig] = None,
  tp_size: Optional[int] = None,
  prefix: str = "",
- correction_bias: Optional[torch.Tensor] = None,
- custom_routing_function: Optional[Callable] = None,
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
  deepep_mode: DeepEPMode = DeepEPMode.auto,
@@ -1258,20 +880,19 @@ class DeepEPMoE(EPMoE):
  intermediate_size=intermediate_size,
  layer_id=layer_id,
  params_dtype=params_dtype,
- renormalize=renormalize,
- use_grouped_topk=use_grouped_topk,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- topk_group=topk_group,
  quant_config=quant_config,
  tp_size=tp_size,
  prefix=prefix,
- correction_bias=correction_bias,
- custom_routing_function=custom_routing_function,
  activation=activation,
  routed_scaling_factor=routed_scaling_factor,
  )
  self.deepep_mode = deepep_mode
+ if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+ assert self.use_fp8_w8a8, (
+ "DeepGEMM requires an fp8_w8a8 model; "
+ "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
+ )
+
  if self.deepep_mode.enable_low_latency():
  assert (
  deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM
sglang/srt/layers/moe/fused_moe_native.py
@@ -9,21 +9,14 @@ import torch
  from torch.nn import functional as F

  from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
- from sglang.srt.layers.moe.topk import select_experts
+ from sglang.srt.layers.moe.topk import TopKOutput


  def fused_moe_forward_native(
  layer: torch.nn.Module,
  x: torch.Tensor,
- use_grouped_topk: bool,
- top_k: int,
- router_logits: torch.Tensor,
- renormalize: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: int = 0,
- custom_routing_function: Optional[Callable] = None,
- correction_bias: Optional[torch.Tensor] = None,
+ topk_output: TopKOutput,
+ *,
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
@@ -34,20 +27,7 @@ def fused_moe_forward_native(
  if apply_router_weight_on_input:
  raise NotImplementedError()

- topk_weights, topk_ids = select_experts(
- hidden_states=x,
- router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- custom_routing_function=custom_routing_function,
- correction_bias=correction_bias,
- routed_scaling_factor=routed_scaling_factor,
- torch_native=True,
- )
+ topk_weights, topk_ids, _ = topk_output

  w13_weights = layer.w13_weight[topk_ids]
  w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
@@ -67,15 +47,8 @@
  def moe_forward_native(
  layer: torch.nn.Module,
  x: torch.Tensor,
- use_grouped_topk: bool,
- top_k: int,
- router_logits: torch.Tensor,
- renormalize: bool,
- topk_group: Optional[int] = None,
- num_expert_group: Optional[int] = None,
- num_fused_shared_experts: int = 0,
- custom_routing_function: Optional[Callable] = None,
- correction_bias: Optional[torch.Tensor] = None,
+ topk_output: TopKOutput,
+ *,
  activation: str = "silu",
  apply_router_weight_on_input: bool = False,
  inplace: bool = True,
@@ -86,20 +59,7 @@ def moe_forward_native(
  if apply_router_weight_on_input:
  raise NotImplementedError()

- topk_weights, topk_ids = select_experts(
- hidden_states=x,
- router_logits=router_logits,
- use_grouped_topk=use_grouped_topk,
- top_k=top_k,
- renormalize=renormalize,
- topk_group=topk_group,
- num_expert_group=num_expert_group,
- num_fused_shared_experts=num_fused_shared_experts,
- custom_routing_function=custom_routing_function,
- correction_bias=correction_bias,
- torch_native=True,
- routed_scaling_factor=routed_scaling_factor,
- )
+ topk_weights, topk_ids, _ = topk_output

  # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
  len_experts = layer.num_experts
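Both native entry points now unpack a precomputed TopKOutput and take their remaining options keyword-only. The toy function below is not sglang code; it is a self-contained sketch, under the assumption that TopKOutput carries (topk_weights, topk_ids, ...), of the gather-and-combine pattern behind the `w13_weights = layer.w13_weight[topk_ids]` line in the hunk above.

# Hedged sketch of the "precomputed top-k" contract used by the native MoE path.
import torch
import torch.nn.functional as F


def toy_fused_moe_native(
    w13_weight: torch.Tensor,    # [num_experts, 2 * inter, hidden]
    w2_weight: torch.Tensor,     # [num_experts, hidden, inter]
    x: torch.Tensor,             # [tokens, hidden]
    topk_weights: torch.Tensor,  # [tokens, top_k]
    topk_ids: torch.Tensor,      # [tokens, top_k]
) -> torch.Tensor:
    w13 = w13_weight[topk_ids]            # gather per-token expert weights
    w1, w3 = torch.chunk(w13, 2, dim=2)   # split fused gate/up projection
    gate = F.silu(torch.einsum("th,tkoh->tko", x, w1))
    up = torch.einsum("th,tkoh->tko", x, w3)
    down = torch.einsum("tko,tkho->tkh", gate * up, w2_weight[topk_ids])
    return (down * topk_weights.unsqueeze(-1)).sum(dim=1)


# Example shapes: 4 tokens, hidden=8, inter=16, 6 experts, top_k=2.
# logits = torch.randn(4, 6)
# topk_weights, topk_ids = torch.topk(torch.softmax(logits, -1), 2, dim=-1)
# y = toy_fused_moe_native(torch.randn(6, 32, 8), torch.randn(6, 8, 16),
#                          torch.randn(4, 8), topk_weights, topk_ids)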
sglang/srt/layers/moe/fused_moe_triton/__init__.py
@@ -1,14 +1,14 @@
  from contextlib import contextmanager
  from typing import Any, Dict, Optional

- import sglang.srt.layers.moe.fused_moe_triton.fused_moe # noqa
  from sglang.srt.layers.moe.fused_moe_triton.fused_moe import (
  fused_experts,
  get_config_file_name,
+ moe_align_block_size,
+ try_get_optimal_moe_config,
  )
  from sglang.srt.layers.moe.fused_moe_triton.layer import (
  FusedMoE,
- FusedMoEMethodBase,
  FusedMoeWeightScaleSupported,
  )

@@ -30,11 +30,11 @@ def get_config() -> Optional[Dict[str, Any]]:

  __all__ = [
  "FusedMoE",
- "FusedMoEMethodBase",
  "FusedMoeWeightScaleSupported",
  "override_config",
  "get_config",
- "fused_moe",
  "fused_experts",
  "get_config_file_name",
+ "moe_align_block_size",
+ "try_get_optimal_moe_config",
  ]
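After this change the package root re-exports moe_align_block_size and try_get_optimal_moe_config and no longer exports FusedMoEMethodBase or fused_moe. The import below reflects only the names listed in the updated __all__; call signatures are not shown in this diff.

# Hedged sketch: importing the 0.4.9.post3 public surface of fused_moe_triton.
from sglang.srt.layers.moe.fused_moe_triton import (
    FusedMoE,
    FusedMoeWeightScaleSupported,
    fused_experts,
    get_config_file_name,
    moe_align_block_size,
    try_get_optimal_moe_config,
)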