sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -21,7 +21,9 @@ import torch.nn as nn
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import (
         fused_add_rmsnorm,
         gemma_fused_add_rmsnorm,
@@ -117,7 +119,27 @@ class GemmaRMSNorm(CustomOp):
         return out
 
 
-if not is_cuda_available():
+class Gemma3RMSNorm(nn.Module):
+    def __init__(self, dim: int, eps: float = 1e-6):
+        super().__init__()
+        self.eps = eps
+        self.weight = nn.Parameter(torch.zeros(dim))
+
+    def _norm(self, x):
+        return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps)
+
+    def forward(self, x):
+        output = self._norm(x.float())
+        # Llama does x.to(float16) * w whilst Gemma3 is (x * w).to(float16)
+        # See https://github.com/huggingface/transformers/pull/29402
+        output = output * (1.0 + self.weight.float())
+        return output.type_as(x)
+
+    def extra_repr(self):
+        return f"{tuple(self.weight.shape)}, eps={self.eps}"
+
+
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
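The Gemma3RMSNorm added above scales the normalized activations by (1 + weight) in float32 and only then casts back to the input dtype, whereas the Llama-style RMSNorm casts first and then scales. A minimal standalone sketch of the two orderings (illustrative helper names, not from the diff):

# Minimal sketch (not from the diff) contrasting the two RMSNorm orderings.
import torch

def rmsnorm_llama_style(x, w, eps=1e-6):
    x_norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return x_norm.to(x.dtype) * w                      # cast down, then scale

def rmsnorm_gemma3_style(x, w, eps=1e-6):
    x_norm = x.float() * torch.rsqrt(x.float().pow(2).mean(-1, keepdim=True) + eps)
    return (x_norm * (1.0 + w.float())).type_as(x)     # scale in fp32, then cast down

x = torch.randn(2, 8, dtype=torch.float16)
w = torch.zeros(8)  # Gemma3 stores the weight as an offset from 1.0
out_llama = rmsnorm_llama_style(x, (1.0 + w).to(x.dtype))
out_gemma = rmsnorm_gemma3_style(x, w)
print(torch.allclose(out_llama.float(), out_gemma.float(), atol=1e-3))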
@@ -23,6 +23,7 @@ from sglang.srt.layers.parameter import (
     PackedvLLMParameter,
     PerTensorScaleParameter,
     RowvLLMParameter,
+    _ColumnvLLMParameter,
 )
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -423,8 +424,6 @@ class ColumnParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        from sglang.srt.layers.parameter import _ColumnvLLMParameter
-
         if isinstance(param, _ColumnvLLMParameter):
             param.load_column_parallel_weight(
                 loaded_weight,
@@ -687,10 +686,19 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
     ):
         if loaded_shard_id is None:
             if isinstance(param, PerTensorScaleParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight, shard_id=0)
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    shard_id=0,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             elif type(param) in (RowvLLMParameter, BasevLLMParameter):
-                param.load_merged_column_weight(loaded_weight=loaded_weight)
+                param.load_merged_column_weight(
+                    loaded_weight=loaded_weight,
+                    tp_rank=self.tp_rank,
+                    tp_size=self.tp_size,
+                )
                 return
             # TODO: @dsikka - move to parameter.py
             self._load_fused_module_from_checkpoint(param, loaded_weight)
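The MergedColumnParallelLinear changes above pass tp_rank and tp_size through to load_merged_column_weight so the parameter object can slice the full checkpoint tensor itself. A toy sketch of the per-rank column shard such a call selects (shard_column is a hypothetical helper used only for illustration, not sglang's API):

# Toy sketch of tensor-parallel column sharding; shard_column is illustrative.
import torch

def shard_column(full_weight: torch.Tensor, tp_rank: int, tp_size: int) -> torch.Tensor:
    # Each rank keeps a contiguous slice of the output dimension (dim 0).
    shard_size = full_weight.shape[0] // tp_size
    return full_weight.narrow(0, tp_rank * shard_size, shard_size)

full = torch.arange(16.0).reshape(8, 2)          # full output_dim = 8
print(shard_column(full, tp_rank=1, tp_size=4))  # rows 2 and 3 of the full weight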
@@ -719,6 +727,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
             shard_offset=shard_offset,
             shard_size=shard_size,
             use_presharded_weights=self.use_presharded_weights,
+            tp_rank=self.tp_rank,
+            tp_size=self.tp_size,
         )
 
 
@@ -782,6 +792,8 @@ class QKVParallelLinear(ColumnParallelLinear):
         else:
             self.num_kv_heads = divide(self.total_num_kv_heads, tp_size)
             self.num_kv_head_replicas = 1
+        self.q_proj_shard_size = self.num_heads * self.head_size
+        self.kv_proj_shard_size = self.num_kv_heads * self.head_size
         input_size = self.hidden_size
         output_size = (
             (self.num_heads + 2 * self.num_kv_heads) * tp_size * self.head_size
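The two new QKVParallelLinear attributes are simply the per-rank projection widths. A small worked example with made-up shapes:

# Example numbers only: 32 query heads, 8 KV heads, head_size 128, tp_size 4.
total_num_heads, total_num_kv_heads, head_size, tp_size = 32, 8, 128, 4

num_heads = total_num_heads // tp_size         # 8 query heads per rank
num_kv_heads = total_num_kv_heads // tp_size   # 2 KV heads per rank

q_proj_shard_size = num_heads * head_size      # 1024 output columns of q_proj per rank
kv_proj_shard_size = num_kv_heads * head_size  # 256 output columns each for k_proj / v_proj
print(q_proj_shard_size, kv_proj_shard_size)   # 1024 256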
@@ -1234,7 +1246,7 @@ class RowParallelLinear(LinearBase):
             assert loaded_weight.numel() == 1
             loaded_weight = loaded_weight.reshape(1)
 
-        if isinstance(param, BasevLLMParameter):
+        if isinstance(param, RowvLLMParameter):
             # This `BasevLLMParameter` is defined in sglang/srt/layers/parameter.py,
             # It supports additional parameters like tp_rank and use_presharded_weights.
             param.load_row_parallel_weight(
@@ -28,7 +28,7 @@ from sglang.srt.distributed import (
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -223,16 +223,18 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        aux_hidden_states: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
            logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-
        # Get the last hidden states and last logits for the next token prediction
        if (
            logits_metadata.forward_mode.is_decode_or_idle()
            or logits_metadata.forward_mode.is_target_verify()
        ):
            pruned_states = hidden_states
+           if aux_hidden_states is not None:
+               aux_pruned_states = [hidden for hidden in aux_hidden_states]
            sample_indices = None
            input_logprob_indices = None
        elif (
@@ -256,6 +258,8 @@ class LogitsProcessor(nn.Module):
                 - 1
             )
             pruned_states = hidden_states[last_index]
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         else:
@@ -319,13 +323,27 @@ class LogitsProcessor(nn.Module):
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
-                hidden_states_to_store = hidden_states
+                if aux_hidden_states is not None:
+                    aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
+                    hidden_states_to_store = aux_hidden_states
+                else:
+                    hidden_states_to_store = hidden_states
             elif logits_metadata.capture_hidden_mode.is_last():
                 # Get the last token hidden states. If sample_indices is None,
                 # pruned states only contain the last tokens already.
-                hidden_states_to_store = (
-                    pruned_states[sample_indices] if sample_indices else pruned_states
-                )
+                if aux_hidden_states is not None:
+                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
+                    hidden_states_to_store = (
+                        aux_pruned_states[sample_indices]
+                        if sample_indices
+                        else aux_pruned_states
+                    )
+                else:
+                    hidden_states_to_store = (
+                        pruned_states[sample_indices]
+                        if sample_indices
+                        else pruned_states
+                    )
             else:
                 assert False, "Should never reach"
 
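For the aux_hidden_states handling above (auxiliary per-layer hidden states, apparently used by the new EAGLE3 support listed in the file table): when auxiliary states are present they are concatenated along the last dimension before being captured, so the stored tensor is num_aux_layers times wider than a single hidden state. A shape-only sketch with illustrative sizes:

# Shape-only sketch with illustrative sizes: three auxiliary layers, each
# [num_tokens, hidden], concatenated along dim=-1 before being captured.
import torch

num_tokens, hidden = 5, 16
aux_hidden_states = [torch.randn(num_tokens, hidden) for _ in range(3)]

hidden_states_to_store = torch.cat(aux_hidden_states, dim=-1)
print(hidden_states_to_store.shape)  # torch.Size([5, 48]), i.e. 3 * hidden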
@@ -410,7 +428,7 @@ class LogitsProcessor(nn.Module):
             logits_metadata.gathered_buffer,
             hidden_states.clone(),
         )
-        dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+        dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
@@ -5,6 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
@@ -16,6 +17,115 @@ if _is_cuda:
 logger = logging.getLogger(__name__)
 
 
+@triton.jit
+def deepep_permute_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
+
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+
+
+@triton.jit
+def deepep_post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offet
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
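As a reading aid for the new deepep_run_moe_deep_preprocess above, here is a rough CPU-only PyTorch rendering of the same bookkeeping (a sketch, assuming negative topk ids mark dropped slots as the num_minus_one handling suggests; deep_preprocess_reference is an illustrative name, not part of the package):

# CPU reference sketch of the DeepEP preprocessing step; topk_ids < 0 are
# treated as invalid slots, which sort to the front and are trimmed off.
import torch

def deep_preprocess_reference(topk_ids: torch.Tensor, num_experts: int):
    flat = topk_ids.view(-1)
    reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)

    # seg_indptr[e] = first position of expert e among the sorted ids,
    # shifted so invalid (-1) entries do not count.
    expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
    seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)
    num_invalid = seg_indptr[0].item()
    seg_indptr = seg_indptr - num_invalid

    # src2dst[i]: where token-slot i lands in the expert-sorted layout.
    src2dst = torch.empty_like(flat, dtype=torch.int64)
    src2dst[reorder_ids] = torch.arange(flat.numel()) - num_invalid

    return reorder_topk_ids[num_invalid:], src2dst, seg_indptr

topk_ids = torch.tensor([[1, 0], [2, -1], [0, 1]])  # -1 = dropped slot
print(deep_preprocess_reference(topk_ids, num_experts=3))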
@@ -33,17 +143,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     tl.store(seg_indptr + expert + 1, target_location + 1)
 
 
-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
@@ -2,8 +2,14 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+
+# TODO: use deep_gemm masked kernel after low latency dispatch
+# import deep_gemm
+# from deep_gemm import (
+#     get_col_major_tma_aligned_tensor,
+#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+# )
 from torch.nn import Module
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,18 +32,23 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
 
 
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
 
+_buffer = None
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -772,3 +783,264 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class DeepEPMoE(EPMoE):
+    """
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    """
+
+    _has_printed = False
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            prefix,
+            correction_bias,
+            custom_routing_function,
+            activation,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+        forward_mode: ForwardMode,
+    ):
+        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
+        if True:  # not forward_mode.is_decode():
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        else:
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )
+
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+            )
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if hidden_states.shape[0] > 0:
+            gateup_output = self.grouped_gemm_runner(
+                a=hidden_states,
+                b=self.w13_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w13_input_scale,
+                scale_b=(
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_output = self.grouped_gemm_runner(
+                a=down_input,
+                b=self.w2_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w2_input_scale,
+                scale_b=(
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+    def forward_deepgemm_masked(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            hidden_states = (
+                hidden_states[0],
+                get_col_major_tma_aligned_tensor(hidden_states[1]),
+            )
+            """
+            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                hidden_states, self.w13_weight, out, masked_m, expected_m
+            )
+            """
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            down_input = (
+                down_input[0],
+                get_col_major_tma_aligned_tensor(down_input[1]),
+            )
+            """
+            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                down_input, self.w2_weight, out, masked_m, expected_m
+            )
+            """
+
+        return down_output
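As a rough mental model for the forward_normal path of DeepEPMoE above: tokens arrive already sorted by expert, seg_indptr delimits each expert's contiguous token segment, and each segment is multiplied by that expert's weight. A dense, unquantized reference sketch (grouped_gemm_reference is illustrative; the real GroupedGemmRunner is a fused kernel with optional FP8 scaling):

# Dense reference sketch of a segmented (grouped) expert GEMM.
import torch

def grouped_gemm_reference(x: torch.Tensor, weights: torch.Tensor,
                           seg_indptr: torch.Tensor) -> torch.Tensor:
    # x:          [num_sorted_tokens, in_dim], already sorted by expert
    # weights:    [num_local_experts, out_dim, in_dim] (assumed layout for this sketch)
    # seg_indptr: [num_local_experts + 1] token offsets, one segment per expert
    out = x.new_empty(x.shape[0], weights.shape[1])
    for e in range(weights.shape[0]):
        start, end = seg_indptr[e].item(), seg_indptr[e + 1].item()
        if end > start:
            out[start:end] = x[start:end] @ weights[e].t()
    return out

x = torch.randn(5, 4)                  # 5 tokens sorted by expert
w = torch.randn(3, 6, 4)               # 3 local experts, out_dim=6, in_dim=4
seg_indptr = torch.tensor([0, 2, 4, 5])
print(grouped_gemm_reference(x, w, seg_indptr).shape)  # torch.Size([5, 6])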