sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -27,7 +27,9 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
+    DPPaddingMode,
     attn_tp_all_gather,
+    attn_tp_all_gather_into_tensor,
     dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
@@ -111,7 +113,8 @@ class LogitsMetadata:
     # Number of tokens to sample per DP rank
     global_num_tokens_for_logprob_cpu: Optional[torch.Tensor] = None
     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
-
+    # The gather mode for DP attention
+    dp_padding_mode: Optional[DPPaddingMode] = None
     # for padding
     padded_static_len: int = -1
 
@@ -163,12 +166,12 @@ class LogitsMetadata:
             forward_batch_gathered_buffer=forward_batch.gathered_buffer,
             global_num_tokens_for_logprob_cpu=forward_batch.global_num_tokens_for_logprob_cpu,
             global_num_tokens_for_logprob_gpu=forward_batch.global_num_tokens_for_logprob_gpu,
+            dp_padding_mode=DPPaddingMode.SUM_LEN,
         )
 
-    def compute_dp_attention_metadata(self, hidden_states: torch.Tensor):
-        if self.global_num_tokens_for_logprob_cpu is None:
-            # we are capturing cuda graph
-            return
+    def compute_dp_attention_metadata(self):
+        # TODO(ch-wan): gathered_buffer here is larger than the actual required size in draft extend,
+        # we may use a smaller buffer in draft extend.
 
         cumtokens = torch.cumsum(self.global_num_tokens_for_logprob_gpu, dim=0)
         dp_rank = get_attention_dp_rank()
@@ -179,18 +182,9 @@
         else:
             dp_local_start_pos = cumtokens[dp_rank - 1]
         dp_local_num_tokens = self.global_num_tokens_for_logprob_gpu[dp_rank]
-        gathered_buffer = torch.zeros(
-            (
-                sum(self.global_num_tokens_for_logprob_cpu),
-                hidden_states.shape[1],
-            ),
-            dtype=hidden_states.dtype,
-            device=hidden_states.device,
-        )
 
         self.dp_local_start_pos = dp_local_start_pos
         self.dp_local_num_tokens = dp_local_num_tokens
-        self.gathered_buffer = gathered_buffer
 
 
 class LogitsProcessor(nn.Module):
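
The removed lines show that compute_dp_attention_metadata no longer allocates a fresh gathered_buffer per call; the method keeps only the cumulative-sum bookkeeping that locates each DP rank's slice (the buffer itself presumably comes from the forward batch, as the from_forward_batch hunk above suggests). A small worked example of that arithmetic, in plain torch with made-up per-rank token counts:

import torch

# Hypothetical logprob token counts contributed by each of four attention-DP ranks.
global_num_tokens_for_logprob = torch.tensor([3, 5, 2, 4])
cumtokens = torch.cumsum(global_num_tokens_for_logprob, dim=0)  # tensor([ 3,  8, 10, 14])

dp_rank = 2
dp_local_start_pos = 0 if dp_rank == 0 else cumtokens[dp_rank - 1]  # 8
dp_local_num_tokens = global_num_tokens_for_logprob[dp_rank]        # 2
# Rank 2 therefore owns rows [8, 10) of the gathered logits buffer.
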
@@ -434,7 +428,7 @@ class LogitsProcessor(nn.Module):
         guarantee the given hidden_states follow this constraint.
         """
         if self.do_tensor_parallel_all_gather_dp_attn:
-            logits_metadata.compute_dp_attention_metadata(hidden_states)
+            logits_metadata.compute_dp_attention_metadata()
             hidden_states, local_hidden_states = (
                 torch.empty_like(logits_metadata.gathered_buffer),
                 hidden_states,
@@ -463,15 +457,31 @@
 
         if self.do_tensor_parallel_all_gather:
             if self.use_attn_tp_group:
-                global_logits = torch.empty(
-                    (self.config.vocab_size, logits.shape[0]),
-                    device=logits.device,
-                    dtype=logits.dtype,
-                )
-                global_logits = global_logits.T
-                attn_tp_all_gather(
-                    list(global_logits.tensor_split(self.attn_tp_size, dim=-1)), logits
-                )
+                if self.config.vocab_size % self.attn_tp_size == 0:
+                    global_logits = torch.empty(
+                        (
+                            self.attn_tp_size,
+                            logits.shape[0],
+                            self.config.vocab_size // self.attn_tp_size,
+                        ),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    attn_tp_all_gather_into_tensor(global_logits, logits)
+                    global_logits = global_logits.permute(1, 0, 2).reshape(
+                        logits.shape[0], self.config.vocab_size
+                    )
+                else:
+                    global_logits = torch.empty(
+                        (self.config.vocab_size, logits.shape[0]),
+                        device=logits.device,
+                        dtype=logits.dtype,
+                    )
+                    global_logits = global_logits.T
+                    attn_tp_all_gather(
+                        list(global_logits.tensor_split(self.attn_tp_size, dim=-1)),
+                        logits,
+                    )
                 logits = global_logits
             else:
                 logits = tensor_model_parallel_all_gather(logits)
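
The new fast path gathers the vocab-parallel logits shards into a dense (attn_tp_size, num_tokens, vocab_size // attn_tp_size) buffer and reassembles the full vocabulary dimension with permute + reshape, replacing the transposed tensor_split gather. A single-process sketch of why that layout round-trips correctly; torch.stack stands in for the collective here, and all names and sizes are illustrative:

import torch

attn_tp_size, num_tokens, vocab_size = 4, 3, 32
full_logits = torch.randn(num_tokens, vocab_size)

# Each rank holds one contiguous vocab shard of shape [num_tokens, vocab_size // attn_tp_size].
shards = list(full_logits.chunk(attn_tp_size, dim=-1))

# An all-gather-into-tensor stacks the per-rank shards along a new leading dimension.
gathered = torch.stack(shards, dim=0)  # [attn_tp_size, num_tokens, vocab_size // attn_tp_size]

# Move the shard dimension next to the vocab dimension, then flatten the two together.
reassembled = gathered.permute(1, 0, 2).reshape(num_tokens, vocab_size)

assert torch.equal(reassembled, full_logits)
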
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -236,7 +236,8 @@ def pre_reorder_triton_kernel(
 ):
     OutDtype = gateup_input_ptr.dtype.element_ty
 
-    src_idx = tl.program_id(0)
+    src_idx_int32 = tl.program_id(0)
+    src_idx = src_idx_int32.to(tl.int64)
     src2dst_ptr = src2dst_ptr + src_idx * topk
     topk_ids_ptr = topk_ids_ptr + src_idx * topk
     src_ptr = input_ptr + src_idx * hidden_size
@@ -255,7 +256,8 @@ def pre_reorder_triton_kernel(
         else:
             scale = 1.0
 
-        dst_idx = tl.load(src2dst_ptr + idx)
+        dst_idx_int32 = tl.load(src2dst_ptr + idx)
+        dst_idx = dst_idx_int32.to(tl.int64)
         dst_ptr = gateup_input_ptr + dst_idx * hidden_size
         for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
             offset = start_offset + vec
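
Both edits to pre_reorder_triton_kernel widen the row index to int64 before it is multiplied by hidden_size, because that product can exceed the int32 range once the expanded token count gets large. Illustrative arithmetic only (not the Triton kernel; the sizes are hypothetical):

hidden_size = 7168                    # hypothetical hidden dimension
dst_idx = 400_000                     # hypothetical destination row after top-k expansion
element_offset = dst_idx * hidden_size
print(element_offset)                 # 2867200000
print(element_offset > 2**31 - 1)     # True: a 32-bit offset would wrap and address the wrong rows
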
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,17 +1,13 @@
 import logging
-from typing import Callable, List, Optional, Tuple
+from typing import List, Optional, Tuple
 
-import einops
 import torch
-from torch.nn import Module
 
-from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
-from sglang.srt.eplb.expert_location_dispatch import ExpertLocationDispatchInfo
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
@@ -27,22 +23,20 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
-from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import Fp8EPMoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
-    scaled_fp8_quant,
     sglang_per_token_group_quant_fp8,
     sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.unquant import UnquantizedEPMoEMethod
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
@@ -53,7 +47,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    set_weight_attrs,
 )
 
 _is_hip = is_hip()
@@ -61,14 +54,11 @@ _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if not _is_npu:
+if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
 
     from sglang.srt.layers.moe.cutlass_w4a8_moe import cutlass_w4a8_moe
 
-if _is_hip:
-    from vllm._custom_ops import scaled_fp8_quant
-
 if _use_aiter:
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
@@ -165,16 +155,9 @@ class EPMoE(torch.nn.Module):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         use_per_token_if_dynamic: bool = True,
@@ -192,24 +175,12 @@
         self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
-        assert (
-            num_fused_shared_experts == 0
-        ), "num_fused_shared_experts is not supported in EP"
-        self.num_fused_shared_experts = num_fused_shared_experts
         self.num_experts_per_partition, self.expert_map = self.determine_expert_map()
         self.start_expert_id = self.tp_rank * self.num_experts_per_partition
         self.end_expert_id = self.start_expert_id + self.num_experts_per_partition - 1
 
         self.top_k = top_k
         self.intermediate_size = intermediate_size
-        self.renormalize = renormalize
-        self.use_grouped_topk = use_grouped_topk
-        if self.use_grouped_topk:
-            assert num_expert_group is not None and topk_group is not None
-        self.num_expert_group = num_expert_group
-        self.topk_group = topk_group
-        self.correction_bias = correction_bias
-        self.custom_routing_function = custom_routing_function
         self.activation = activation
         self.routed_scaling_factor = routed_scaling_factor
         self.use_per_token_if_dynamic = use_per_token_if_dynamic
@@ -314,33 +285,24 @@
         )
         return (local_num_experts, expert_map)
 
-    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
-            return self.forward_deepgemm(hidden_states, router_logits)
+            return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return self.forward_normal(hidden_states, router_logits)
+            return self.forward_normal(hidden_states, topk_output)
 
     def forward_deepgemm(
-        self, hidden_states: torch.Tensor, router_logits: torch.Tensor
+        self,
+        hidden_states: torch.Tensor,
+        topk_output: TopKOutput,
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
-        topk_weights, topk_ids = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-        )
+
+        topk_weights, topk_ids, _ = topk_output
 
         if not self.use_block_quant:
             # Convert per-tensor quant to per-block quant by repeating scales for forward_deepgemm
@@ -472,8 +434,10 @@ class EPMoE(torch.nn.Module):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         assert self.quant_method is not None
+        topk_weights, topk_ids, _ = topk_output
+
         hidden_states_shape = hidden_states.shape
         hidden_states_dtype = hidden_states.dtype
         hidden_states_device = hidden_states.device
@@ -484,23 +448,6 @@
             use_per_token_if_dynamic=self.use_per_token_if_dynamic,
         )
 
-        topk_weights, topk_ids = select_experts(
-            hidden_states=hidden_states,
-            router_logits=router_logits,
-            top_k=self.top_k,
-            use_grouped_topk=self.use_grouped_topk,
-            renormalize=self.renormalize,
-            topk_group=self.topk_group,
-            num_expert_group=self.num_expert_group,
-            num_fused_shared_experts=self.num_fused_shared_experts,
-            correction_bias=self.correction_bias,
-            custom_routing_function=self.custom_routing_function,
-            routed_scaling_factor=self.routed_scaling_factor,
-            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
-                layer_id=self.layer_id,
-            ),
-        )
-
         if self.use_w4afp8:
             local_topk_ids = topk_ids
             if self.expert_map is not None:
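
Taken together, these hunks move expert selection out of EPMoE: forward, forward_deepgemm, and forward_normal no longer call select_experts on raw router logits and instead unpack a precomputed TopKOutput. A hedged sketch of the resulting calling convention; the field names are illustrative and not sglang's exact API:

from typing import NamedTuple

import torch

class TopKOutputSketch(NamedTuple):
    topk_weights: torch.Tensor  # [num_tokens, top_k] routing weights
    topk_ids: torch.Tensor      # [num_tokens, top_k] expert indices
    extra: object               # third field that EPMoE ignores via `_`

def moe_forward(hidden_states: torch.Tensor, topk_output: TopKOutputSketch) -> torch.Tensor:
    # Mirrors the unpacking added to forward_normal / forward_deepgemm.
    topk_weights, topk_ids, _ = topk_output
    # ... dispatch rows of hidden_states to experts by topk_ids, then combine with topk_weights ...
    return hidden_states
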
@@ -904,324 +851,6 @@ class EPMoE(torch.nn.Module):
             param_data[expert_id] = loaded_weight
 
 
-class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
-
-    def create_weights(
-        self,
-        layer: torch.nn.Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        # Fused gate_up_proj (column parallel)
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        # down_proj (row parallel)
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # scale
-        layer.register_parameter("w13_input_scale", None)
-        layer.register_parameter("w13_weight_scale", None)
-
-        ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-
-        w2_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_input_scale", w2_input_scale)
-        set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        w2_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
-class Fp8EPMoEMethod(Fp8MoEMethod):
-    """MoE method for FP8.
-    Supports loading FP8 checkpoints with static weight scale and
-    dynamic/static activation scale.
-
-    Args:
-        quant_config: The quantization config.
-    """
-
-    def __init__(self, quant_config: Fp8Config):
-        self.quant_config = quant_config
-        self.block_quant = self.quant_config.weight_block_size is not None
-
-    def create_weights(
-        self,
-        layer: Module,
-        num_experts_per_partition: int,
-        hidden_size: int,
-        intermediate_size: int,
-        params_dtype: torch.dtype,
-        **extra_weight_attrs,
-    ):
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            params_dtype = torch.float8_e4m3fn
-
-        tp_size = get_tensor_model_parallel_world_size()
-        if self.block_quant:
-            block_n, block_k = (
-                self.quant_config.weight_block_size[0],
-                self.quant_config.weight_block_size[1],
-            )
-            # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by column parallel or enabling merged weights
-            if intermediate_size % block_n != 0:
-                raise ValueError(
-                    f"The output_size of gate's and up's weight = "
-                    f"{intermediate_size} is not divisible by "
-                    f"weight quantization block_n = {block_n}."
-                )
-            if tp_size > 1:
-                # Required by row parallel
-                if intermediate_size % block_k != 0:
-                    raise ValueError(
-                        f"The input_size of down's weight = "
-                        f"{intermediate_size} is not divisible by "
-                        f"weight quantization block_k = {block_k}."
-                    )
-
-        # WEIGHTS
-        w13_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                2 * intermediate_size,
-                hidden_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight", w13_weight)
-        set_weight_attrs(w13_weight, extra_weight_attrs)
-
-        w2_weight = torch.nn.Parameter(
-            torch.empty(
-                num_experts_per_partition,
-                hidden_size,
-                intermediate_size,
-                dtype=params_dtype,
-            ),
-            requires_grad=False,
-        )
-        layer.register_parameter("w2_weight", w2_weight)
-        set_weight_attrs(w2_weight, extra_weight_attrs)
-
-        # WEIGHT_SCALES
-        if self.block_quant:
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    2 * ((intermediate_size + block_n - 1) // block_n),
-                    (hidden_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    num_experts_per_partition,
-                    (hidden_size + block_n - 1) // block_n,
-                    (intermediate_size + block_k - 1) // block_k,
-                    dtype=torch.float32,
-                ),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
-            layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
-            assert self.quant_config.activation_scheme == "dynamic"
-        else:
-            # WEIGHT_SCALES
-            # Allocate 2 scales for w1 and w3 respectively.
-            w13_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, 2, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_weight_scale", w13_weight_scale)
-
-            w2_weight_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_weight_scale", w2_weight_scale)
-        # Add the quantization method used (per tensor/grouped/channel)
-        # to ensure the weight scales are loaded in properly
-        extra_weight_attrs.update(
-            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
-            if self.block_quant
-            else {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
-        )
-        # If loading fp8 checkpoint, pass the weight loaders.
-        # If loading an fp16 checkpoint, do not (we will quantize in
-        # process_weights_after_loading()
-        if self.quant_config.is_checkpoint_fp8_serialized:
-            set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-            set_weight_attrs(w2_weight_scale, extra_weight_attrs)
-
-        # INPUT_SCALES
-        if self.quant_config.activation_scheme == "static":
-            if not self.quant_config.is_checkpoint_fp8_serialized:
-                raise ValueError(
-                    "Found static activation scheme for checkpoint that "
-                    "was not serialized fp8."
-                )
-
-            w13_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w13_input_scale", w13_input_scale)
-            set_weight_attrs(w13_input_scale, extra_weight_attrs)
-
-            w2_input_scale = torch.nn.Parameter(
-                torch.ones(num_experts_per_partition, dtype=torch.float32),
-                requires_grad=False,
-            )
-            layer.register_parameter("w2_input_scale", w2_input_scale)
-            set_weight_attrs(w2_input_scale, extra_weight_attrs)
-
-        else:
-            layer.w13_input_scale = None
-            layer.w2_input_scale = None
-
-    def process_weights_after_loading(self, layer: Module) -> None:
-
-        # If checkpoint is fp16, quantize in place.
-        if not self.quant_config.is_checkpoint_fp8_serialized:
-            # If rocm, use float8_e4m3fnuz as dtype
-            fp8_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-            w13_weight = torch.empty_like(layer.w13_weight.data, dtype=fp8_dtype)
-            w2_weight = torch.empty_like(layer.w2_weight.data, dtype=fp8_dtype)
-
-            layer.w13_weight_scale = torch.nn.Parameter(
-                torch.ones(
-                    layer.num_experts_per_partition,
-                    dtype=torch.float32,
-                    device=w13_weight.device,
-                ),
-                requires_grad=False,
-            )
-
-            for expert in range(layer.num_experts_per_partition):
-                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                )
-                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                )
-            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
-            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-            return
-
-        # If checkpoint is fp8, we need to handle that the
-        # MoE kernels require single activation scale and single weight
-        # scale for w13 per expert.
-        else:
-            if self.quant_config.activation_scheme == "static":
-                if layer.w13_input_scale is None or layer.w2_input_scale is None:
-                    raise ValueError(
-                        "QuantConfig has static quantization, but found "
-                        "activation scales are None."
-                    )
-                layer.w13_weight_scale = torch.nn.Parameter(
-                    torch.max(layer.w13_weight_scale, dim=1).values,
-                    requires_grad=False,
-                )
-            if self.block_quant:
-                # If ROCm, normalize the weights and scales to e4m3fnuz
-                if _is_fp8_fnuz:
-                    # activation_scheme: dynamic
-                    w13_weight, w13_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w13_weight,
-                        weight_scale=layer.w13_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    w2_weight, w2_weight_scale, _ = normalize_e4m3fn_to_e4m3fnuz(
-                        weight=layer.w2_weight,
-                        weight_scale=layer.w2_weight_scale_inv,
-                        input_scale=None,
-                    )
-                    # Reset the parameter
-                    layer.w13_weight = torch.nn.Parameter(
-                        w13_weight, requires_grad=False
-                    )
-                    layer.w13_weight_scale_inv = torch.nn.Parameter(
-                        w13_weight_scale, requires_grad=False
-                    )
-                    layer.w13_input_scale = None
-                    layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
-                    layer.w2_weight_scale_inv = torch.nn.Parameter(
-                        w2_weight_scale, requires_grad=False
-                    )
-                    layer.w2_input_scale = None
-                if _use_aiter:
-                    layer.w13_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w13_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-                    layer.w2_weight = torch.nn.Parameter(
-                        shuffle_weight(layer.w2_weight.data, (16, 16)),
-                        requires_grad=False,
-                    )
-            return
-
-    def apply(
-        self,
-        layer: torch.nn.Module,
-        x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        custom_routing_function: Optional[Callable] = None,
-    ) -> torch.Tensor:
-        raise NotImplementedError
-
-
 class DeepEPMoE(EPMoE):
     """
     MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
@@ -1237,16 +866,9 @@ class DeepEPMoE(EPMoE):
         intermediate_size: int,
         layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
-        renormalize: bool = True,
-        use_grouped_topk: bool = False,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        topk_group: Optional[int] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
-        correction_bias: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,
@@ -1258,20 +880,19 @@
             intermediate_size=intermediate_size,
             layer_id=layer_id,
             params_dtype=params_dtype,
-            renormalize=renormalize,
-            use_grouped_topk=use_grouped_topk,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            topk_group=topk_group,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
-            correction_bias=correction_bias,
-            custom_routing_function=custom_routing_function,
             activation=activation,
            routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
+        if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM:
+            assert self.use_fp8_w8a8, (
+                "DeepGEMM requires an fp8_w8a8 model; "
+                "alternatively, you can disable DeepGEMM by turning off the ENABLE_JIT_DEEPGEMM environment variable."
+            )
+
         if self.deepep_mode.enable_low_latency():
             assert (
                 deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM