sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -23,11 +23,12 @@ import triton.language as tl
 from torch import nn
 
 from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )
 from sglang.srt.layers.dp_attention import (
-    dp_gather,
+    dp_gather_replicate,
     dp_scatter,
     get_attention_dp_rank,
     get_attention_dp_size,
@@ -222,16 +223,18 @@ class LogitsProcessor(nn.Module):
         hidden_states,
         lm_head: VocabParallelEmbedding,
         logits_metadata: Union[LogitsMetadata, ForwardBatch],
+        aux_hidden_states: Optional[torch.Tensor] = None,
     ) -> LogitsProcessorOutput:
         if isinstance(logits_metadata, ForwardBatch):
             logits_metadata = LogitsMetadata.from_forward_batch(logits_metadata)
-
         # Get the last hidden states and last logits for the next token prediction
         if (
             logits_metadata.forward_mode.is_decode_or_idle()
             or logits_metadata.forward_mode.is_target_verify()
         ):
             pruned_states = hidden_states
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         elif (
@@ -255,6 +258,8 @@
                 - 1
             )
             pruned_states = hidden_states[last_index]
+            if aux_hidden_states is not None:
+                aux_pruned_states = [hidden[last_index] for hidden in aux_hidden_states]
             sample_indices = None
             input_logprob_indices = None
         else:
@@ -318,13 +323,27 @@
         hidden_states_to_store: Optional[torch.Tensor] = None
         if logits_metadata.capture_hidden_mode.need_capture():
             if logits_metadata.capture_hidden_mode.is_full():
-                hidden_states_to_store = hidden_states
+                if aux_hidden_states is not None:
+                    aux_hidden_states = torch.cat(aux_hidden_states, dim=-1)
+                    hidden_states_to_store = aux_hidden_states
+                else:
+                    hidden_states_to_store = hidden_states
             elif logits_metadata.capture_hidden_mode.is_last():
                 # Get the last token hidden states. If sample_indices is None,
                 # pruned states only contain the last tokens already.
-                hidden_states_to_store = (
-                    pruned_states[sample_indices] if sample_indices else pruned_states
-                )
+                if aux_hidden_states is not None:
+                    aux_pruned_states = torch.cat(aux_pruned_states, dim=-1)
+                    hidden_states_to_store = (
+                        aux_pruned_states[sample_indices]
+                        if sample_indices
+                        else aux_pruned_states
+                    )
+                else:
+                    hidden_states_to_store = (
+                        pruned_states[sample_indices]
+                        if sample_indices
+                        else pruned_states
+                    )
             else:
                 assert False, "Should never reach"
 
@@ -409,7 +428,7 @@
                 logits_metadata.gathered_buffer,
                 hidden_states.clone(),
             )
-            dp_gather(hidden_states, local_hidden_states, logits_metadata, "embedding")
+            dp_gather_replicate(hidden_states, local_hidden_states, logits_metadata)
 
         if hasattr(lm_head, "weight"):
             logits = torch.matmul(
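
The LogitsProcessor hunks above add an optional aux_hidden_states argument (a list of extra per-layer hidden states, presumably feeding the EAGLE3 draft path added elsewhere in this release) that is pruned the same way as hidden_states and, when hidden-state capture is enabled, concatenated along the feature dimension before being stored. A minimal standalone sketch of that behavior follows; the shapes and last_index values are hypothetical, not taken from the diff:

import torch

# Illustrative only: three auxiliary layer outputs for 6 tokens of width 4.
num_tokens, width = 6, 4
aux_hidden_states = [torch.randn(num_tokens, width) for _ in range(3)]
last_index = torch.tensor([2, 5])  # hypothetical last-token positions per request

# capture_hidden_mode "last": prune to last tokens, then concat feature-wise.
aux_pruned_states = [h[last_index] for h in aux_hidden_states]
stored_last = torch.cat(aux_pruned_states, dim=-1)   # shape (2, 12)

# capture_hidden_mode "full": concat feature-wise for every token.
stored_full = torch.cat(aux_hidden_states, dim=-1)   # shape (6, 12)
print(stored_last.shape, stored_full.shape)
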
sglang/srt/layers/moe/ep_moe/kernels.py

@@ -5,6 +5,7 @@ import torch
 import triton
 import triton.language as tl
 
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
 from sglang.srt.utils import is_cuda
 
@@ -16,6 +17,115 @@ if _is_cuda:
 logger = logging.getLogger(__name__)
 
 
+@triton.jit
+def deepep_permute_triton_kernel(
+    input_ptr,
+    gateup_input_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    a1_scales_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    OutDtype = gateup_input_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+
+    src_ptr = input_ptr + src_idx * hidden_size
+
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        in_data = tl.load(src_ptr + offset, mask=mask).to(OutDtype)
+
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                dst_ptr = gateup_input_ptr + dst_idx * hidden_size
+                tl.store(dst_ptr + offset, in_data, mask=mask)
+
+
+@triton.jit
+def deepep_post_reorder_triton_kernel(
+    down_output_ptr,
+    output_ptr,
+    src2dst_ptr,
+    topk_ids_ptr,
+    topk_weights_ptr,
+    topk,
+    hidden_size,
+    BLOCK_SIZE: tl.constexpr,
+):
+    InDtype = down_output_ptr.dtype.element_ty
+
+    src_idx = tl.program_id(0)
+    src2dst_ptr = src2dst_ptr + src_idx * topk
+    topk_ids_ptr = topk_ids_ptr + src_idx * topk
+    topk_weights_ptr = topk_weights_ptr + src_idx * topk
+
+    store_ptr = output_ptr + src_idx * hidden_size
+    for start_offset in tl.range(0, hidden_size, BLOCK_SIZE):
+        offset = start_offset + tl.arange(0, BLOCK_SIZE)
+        mask = offset < hidden_size
+        sum_vec = tl.zeros([BLOCK_SIZE], dtype=InDtype)
+        for idx in range(topk):
+            dst_idx = tl.load(src2dst_ptr + idx)
+            if dst_idx >= 0:
+                weigh_scale = tl.load(topk_weights_ptr + idx).to(InDtype)
+                load_ptr = down_output_ptr + dst_idx * hidden_size
+                in_data = tl.load(load_ptr + offset, mask=mask)
+                sum_vec += in_data * weigh_scale
+        tl.store(store_ptr + offset, sum_vec, mask=mask)
+
+
+@triton.jit
+def compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    tl.store(src2dst + src_id, dst_id, mask=mask)
+
+
+@triton.jit
+def deepep_compute_src2dst_triton_kernel(
+    reorder_ids, src2dst, num_toks, num_minus_one, BLOCK_SIZE: tl.constexpr
+):
+    pid = tl.program_id(axis=0)
+    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = dst_id < num_toks
+    src_id = tl.load(reorder_ids + dst_id, mask=mask)
+    num_invalid = tl.load(num_minus_one)
+    tl.store(src2dst + src_id, dst_id - num_invalid, mask=mask)
+
+
+def deepep_run_moe_deep_preprocess(topk_ids: torch.Tensor, num_experts: int):
+    reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
+    seg_indptr = torch.empty(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
+    src2dst = torch.empty(topk_ids.numel(), device=topk_ids.device, dtype=torch.int64)
+
+    # Find offset
+    expert_ids = torch.arange(
+        num_experts + 1, device=topk_ids.device, dtype=reorder_topk_ids.dtype
+    )
+    torch.searchsorted(reorder_topk_ids, expert_ids, out=seg_indptr)
+    num_minus_one = seg_indptr[0]
+    seg_indptr = seg_indptr - num_minus_one
+
+    BLOCK_SIZE = 512
+    grid = (triton.cdiv(topk_ids.numel(), BLOCK_SIZE),)
+    deepep_compute_src2dst_triton_kernel[grid](
+        reorder_ids, src2dst, topk_ids.numel(), num_minus_one, BLOCK_SIZE
+    )
+    reorder_topk_ids = reorder_topk_ids[num_minus_one:]
+    return reorder_topk_ids, src2dst, seg_indptr
+
+
 @triton.jit
 def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     expert = tl.program_id(0)
@@ -33,17 +143,6 @@ def compute_seg_indptr_triton_kernel(reorder_topk_ids, seg_indptr, num_toks):
     tl.store(seg_indptr + expert + 1, target_location + 1)
 
 
-@triton.jit
-def compute_src2dst_triton_kernel(
-    reorder_ids, src2dst, num_toks, BLOCK_SIZE: tl.constexpr
-):
-    pid = tl.program_id(axis=0)
-    dst_id = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
-    mask = dst_id < num_toks
-    src_id = tl.load(reorder_ids + dst_id, mask=mask)
-    tl.store(src2dst + src_id, dst_id, mask=mask)
-
-
 def run_moe_ep_preproess(topk_ids: torch.Tensor, num_experts: int):
     reorder_topk_ids, reorder_ids = torch.sort(topk_ids.view(-1), stable=True)
     seg_indptr = torch.zeros(num_experts + 1, device=topk_ids.device, dtype=torch.int64)
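
For context on the deepep_run_moe_deep_preprocess helper added above: it sorts the flattened top-k expert ids, derives per-expert segment offsets with torch.searchsorted, and builds a src2dst map that sends each (token, k) slot to its row in the expert-sorted layout, with slots routed to expert -1 pushed to negative destinations so later kernels skip them. Below is a pure-PyTorch sketch of the same bookkeeping, CPU only, with a tiny hypothetical input and no Triton launch; it is illustrative, not the shipped code path:

import torch

num_experts = 4
topk_ids = torch.tensor([[0, 2], [2, -1], [1, 0]])   # -1 marks a dropped slot

flat = topk_ids.view(-1)
reorder_topk_ids, reorder_ids = torch.sort(flat, stable=True)

# seg_indptr[e] = first sorted position of expert e, shifted so the leading
# -1 entries are excluded from every segment.
expert_ids = torch.arange(num_experts + 1, dtype=reorder_topk_ids.dtype)
seg_indptr = torch.searchsorted(reorder_topk_ids, expert_ids)
num_minus_one = seg_indptr[0].clone()
seg_indptr = seg_indptr - num_minus_one

# src2dst[i] = destination row (in expert-sorted order) of flattened slot i;
# dropped slots land on negative destinations, mirroring the Triton kernel.
src2dst = torch.empty_like(flat)
src2dst[reorder_ids] = torch.arange(flat.numel()) - num_minus_one

reorder_topk_ids = reorder_topk_ids[num_minus_one:]
print(reorder_topk_ids, seg_indptr, src2dst)
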
sglang/srt/layers/moe/ep_moe/layer.py

@@ -2,8 +2,14 @@ import logging
 from typing import Callable, List, Optional, Tuple
 
 import torch
+
+# TODO: use deep_gemm masked kernel after low latency dispatch
+# import deep_gemm
+# from deep_gemm import (
+#     get_col_major_tma_aligned_tensor,
+#     m_grouped_gemm_fp8_fp8_bf16_nt_masked,
+# )
 from torch.nn import Module
-from vllm import _custom_ops as vllm_ops
 
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
@@ -26,18 +32,23 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.utils import is_cuda, is_hip, set_weight_attrs
 
 _is_cuda = is_cuda()
 
 if _is_cuda:
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
 
 
 logger = logging.getLogger(__name__)
 
 _is_hip = is_hip()
 
+_buffer = None
+
 
 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
@@ -772,3 +783,264 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
         custom_routing_function: Optional[Callable] = None,
     ) -> torch.Tensor:
         raise NotImplementedError
+
+
+class DeepEPMoE(EPMoE):
+    """
+    MoE Expert Parallel Impl based on DeepEP (https://github.com/deepseek-ai/DeepEP/tree/main)
+    """
+
+    _has_printed = False
+
+    def __init__(
+        self,
+        num_experts: int,
+        top_k: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: Optional[torch.dtype] = None,
+        renormalize: bool = True,
+        use_grouped_topk: bool = False,
+        num_expert_group: Optional[int] = None,
+        topk_group: Optional[int] = None,
+        quant_config: Optional[QuantizationConfig] = None,
+        tp_size: Optional[int] = None,
+        prefix: str = "",
+        correction_bias: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        activation: str = "silu",
+    ):
+        super().__init__(
+            num_experts,
+            top_k,
+            hidden_size,
+            intermediate_size,
+            params_dtype,
+            renormalize,
+            use_grouped_topk,
+            num_expert_group,
+            topk_group,
+            quant_config,
+            tp_size,
+            prefix,
+            correction_bias,
+            custom_routing_function,
+            activation,
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+        forward_mode: ForwardMode,
+    ):
+        # Todo: use m_grouped_gemm_fp8_fp8_bf16_nt_masked after low_latency dispatch (decode)
+        if True:  # not forward_mode.is_decode():
+            return self.forward_normal(hidden_states, reorder_topk_ids, seg_indptr)
+        else:
+            return self.forward_deepgemm_masked(
+                hidden_states, reorder_topk_ids, seg_indptr
+            )
+
+    def forward_normal(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+        if self.grouped_gemm_runner is None:
+            self.grouped_gemm_runner = GroupedGemmRunner(
+                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
+            )
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+        weight_indices_cur_rank = torch.arange(
+            0,
+            self.num_experts_per_partition,
+            device=hidden_states.device,
+            dtype=torch.int64,
+        )
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+
+        if hidden_states.shape[0] > 0:
+            gateup_output = self.grouped_gemm_runner(
+                a=hidden_states,
+                b=self.w13_weight,
+                c=gateup_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w13_input_scale,
+                scale_b=(
+                    self.w13_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w13_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            down_output = self.grouped_gemm_runner(
+                a=down_input,
+                b=self.w2_weight,
+                c=down_output,
+                batch_size=self.num_experts_per_partition,
+                weight_column_major=True,
+                seg_indptr=seg_indptr,
+                weight_indices=weight_indices_cur_rank,
+                use_fp8_w8a8=self.use_fp8_w8a8,
+                scale_a=self.w2_input_scale,
+                scale_b=(
+                    self.w2_weight_scale_inv
+                    if self.use_block_quant
+                    else self.w2_weight_scale
+                ),
+                block_shape=self.block_shape,
+            )
+        return down_output
+
+    def forward_deepgemm_masked(
+        self,
+        hidden_states: torch.Tensor,
+        reorder_topk_ids: torch.Tensor,
+        seg_indptr: torch.Tensor,
+    ):
+        assert self.quant_method is not None
+        assert self.activation == "silu"
+
+        if self.activation_scheme == "dynamic" and not self.use_block_quant:
+            max_value = (
+                torch.max(hidden_states)
+                .repeat(self.num_experts_per_partition)
+                .to(torch.float32)
+            )
+            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
+
+        # GroupGemm-0
+        gateup_output = torch.empty(
+            hidden_states.shape[0],
+            self.w13_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if hidden_states.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            hidden_states = (
+                hidden_states[0],
+                get_col_major_tma_aligned_tensor(hidden_states[1]),
+            )
+            """
+            gateup_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                hidden_states, self.w13_weight, out, masked_m, expected_m
+            )
+            """
+
+        # Act
+        down_input = torch.empty(
+            gateup_output.shape[0],
+            gateup_output.shape[1] // 2,
+            device=gateup_output.device,
+            dtype=(
+                self.fp8_dtype
+                if (self.use_fp8_w8a8 and not self.use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        if self.w2_input_scale is None and not self.use_block_quant:
+            self.w2_input_scale = torch.ones(
+                self.num_experts_per_partition,
+                dtype=torch.float32,
+                device=hidden_states.device,
+            )
+
+        if self.activation == "silu":
+            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
+                gateup_output,
+                down_input,
+                gateup_output.shape[1],
+                reorder_topk_ids,
+                self.w2_input_scale,
+                0,
+                self.num_experts_per_partition - 1,
+                BLOCK_SIZE=512,
+            )
+        else:
+            raise ValueError(f"Unsupported activation: {self.activation=}")
+
+        # GroupGemm-1
+        down_output = torch.empty(
+            down_input.shape[0],
+            self.w2_weight.shape[1],
+            device=hidden_states.device,
+            dtype=hidden_states.dtype,
+        )
+        if down_input.shape[0] > 0:
+            # Transpose earlier so that the testing will not trigger transposing kernels
+            down_input = (
+                down_input[0],
+                get_col_major_tma_aligned_tensor(down_input[1]),
+            )
+            """
+            down_output = deep_gemm.m_grouped_gemm_fp8_fp8_bf16_nt_masked(
+                down_input, self.w2_weight, out, masked_m, expected_m
+            )
+            """
+
+        return down_output
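
To make the structure of DeepEPMoE.forward_normal above easier to follow: once tokens have been permuted into per-expert segments (described by seg_indptr), it runs a grouped gate/up GEMM, a fused SiLU-and-mul activation, and a grouped down GEMM over each segment. The sketch below is an illustrative dense-PyTorch reference for that segmented computation only; it is not the shipped grouped-GEMM or FP8 path, and the weight shapes and sizes are assumptions:

import torch

def moe_segments_reference(x_sorted, seg_indptr, w13, w2):
    # x_sorted:   (num_tokens_sorted, hidden)       tokens already grouped by expert
    # seg_indptr: (num_experts + 1,)                segment boundaries per expert
    # w13:        (num_experts, 2 * inter, hidden)  gate/up projection weights
    # w2:         (num_experts, hidden, inter)      down projection weights
    out = torch.empty_like(x_sorted)
    for e in range(seg_indptr.numel() - 1):
        s, t = seg_indptr[e].item(), seg_indptr[e + 1].item()
        if s == t:
            continue
        gateup = x_sorted[s:t] @ w13[e].t()          # GroupGemm-0
        gate, up = gateup.chunk(2, dim=-1)
        act = torch.nn.functional.silu(gate) * up    # silu_and_mul
        out[s:t] = act @ w2[e].t()                   # GroupGemm-1
    return out

hidden, inter, num_experts = 8, 16, 4
seg_indptr = torch.tensor([0, 2, 3, 5, 5])
x = torch.randn(5, hidden)
w13 = torch.randn(num_experts, 2 * inter, hidden)
w2 = torch.randn(num_experts, hidden, inter)
print(moe_segments_reference(x, seg_indptr, w13, w2).shape)
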