sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py (new file, +416)
@@ -0,0 +1,416 @@
+ try:
+     from deep_ep import Buffer
+
+     use_deepep = True
+ except ImportError:
+     use_deepep = False
+
+ from typing import Optional, Tuple
+
+ import torch
+ import torch.distributed as dist
+
+ from sglang.srt.layers.moe.ep_moe.kernels import (
+     deepep_permute_triton_kernel,
+     deepep_post_reorder_triton_kernel,
+     deepep_run_moe_deep_preprocess,
+ )
+ from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+ _buffer_normal = None
+ _buffer_low_latency = None
+
+
+ def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+     """
+     Copy from DeepEP example usage in model inference prefilling.
+     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
+     """
+
+     global _buffer_normal
+
+     num_nvl_bytes, num_rdma_bytes = 0, 0
+     for config in (
+         Buffer.get_dispatch_config(group.size()),
+         Buffer.get_combine_config(group.size()),
+     ):
+         num_nvl_bytes = max(
+             config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
+         )
+         num_rdma_bytes = max(
+             config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
+         )
+
+     if (
+         _buffer_normal is None
+         or _buffer_normal.group != group
+         or _buffer_normal.num_nvl_bytes < num_nvl_bytes
+         or _buffer_normal.num_rdma_bytes < num_rdma_bytes
+     ):
+         _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
+     return _buffer_normal
+
+
+ def get_buffer_low_latency(
+     group: dist.ProcessGroup,
+     num_max_dispatch_tokens_per_rank: int,
+     hidden: int,
+     num_experts: int,
+ ):
+     """
+     Copy from DeepEP example usage in model inference decoding.
+     https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+     """
+
+     global _buffer_low_latency
+     num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
+         num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
+     )
+
+     if (
+         _buffer_low_latency is None
+         or _buffer_low_latency.group != group
+         or not _buffer_low_latency.low_latency_mode
+         or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+     ):
+         assert num_experts % group.size() == 0
+         _buffer_low_latency = Buffer(
+             group,
+             0,
+             num_rdma_bytes,
+             low_latency_mode=True,
+             num_qps_per_rank=num_experts // group.size(),
+         )
+     return _buffer_low_latency
+
+
+ class DeepEPDispatcher:
+     """
+     Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+     https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+     """
+
+     def __init__(
+         self,
+         group: torch.distributed.ProcessGroup,
+         router_topk: int,
+         permute_fusion: bool = False,
+         capacity_factor: float = None,
+         num_experts: int = None,
+         num_local_experts: int = None,
+         hidden_size: int = None,
+         params_dtype: torch.dtype = None,
+         async_finish: bool = False,
+     ):
+         self.group = group
+         self.router_topk = router_topk
+         self.capacity_factor = capacity_factor
+         self.permute_fusion = permute_fusion
+         self.num_experts = num_experts
+         self.num_local_experts = num_local_experts
+         self.hidden_size = hidden_size
+         self.recv_expert_count = None
+         self.params_dtype = params_dtype
+         self.params_bytes = 2
+         # Metadata
+         self.token_indices = None
+         self.token_probs = None
+         # Handle used for combine operation
+         self.handle = None
+         self.async_finish = async_finish
+
+         # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
+         # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+         self.num_max_dispatch_tokens_per_rank = 128
+
+         if not use_deepep:
+             raise ImportError(
+                 "DeepEP is not installed. Please install DeepEP package from "
+                 "https://github.com/deepseek-ai/deepep."
+             )
+         self.buffer_normal = get_buffer_normal(
+             self.group, self.hidden_size * self.params_bytes
+         )
+         self.buffer_low_latency = None
+         # Todo: enable low latency dispatch
+         """
+         self.buffer_low_latency = get_buffer_low_latency(
+             self.group,
+             self.num_max_dispatch_tokens_per_rank,
+             self.hidden_size * self.params_bytes,
+             self.num_experts,
+         )
+         """
+
+     def deepep_permute(
+         self,
+         hidden_states,
+         fp8_dtype=None,
+         use_fp8_w8a8=False,
+         use_block_quant=False,
+     ):
+         reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+             self.topk_idx, self.num_experts
+         )
+         num_total_tokens = reorder_topk_ids.numel()
+         gateup_input = torch.empty(
+             (int(num_total_tokens), hidden_states.shape[1]),
+             device=hidden_states.device,
+             dtype=(
+                 fp8_dtype
+                 if (use_fp8_w8a8 and not use_block_quant)
+                 else hidden_states.dtype
+             ),
+         )
+         # PreReorder
+         deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+             hidden_states,
+             gateup_input,
+             src2dst,
+             self.topk_idx,
+             None,
+             self.router_topk,
+             hidden_states.shape[1],
+             BLOCK_SIZE=512,
+         )
+         self.src2dst = src2dst
+         return reorder_topk_ids, seg_indptr, gateup_input
+
+     def dispatch(
+         self,
+         hidden_states: torch.Tensor,
+         topk_idx: torch.Tensor,
+         topk_weights: torch.Tensor,
+         num_experts: int,
+         forward_mode: ForwardMode,
+         num_max_dispatch_tokens_per_rank: int = 128,
+     ) -> Tuple[torch.Tensor, torch.Tensor]:
+         topk_idx = topk_idx.to(torch.int64)
+         # Todo: enable low latency dispatch
+         if True:  # not forward_mode.is_decode():
+             (
+                 hidden_states,
+                 topk_idx,
+                 topk_weights,
+                 num_recv_tokens_per_expert_list,
+                 handle,
+                 event,
+             ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+             self.tokens_per_expert = torch.tensor(
+                 num_recv_tokens_per_expert_list,
+                 device=hidden_states.device,
+                 dtype=torch.int64,
+             )
+         else:
+             hidden_states, recv_expert_count, handle, event, hook = (
+                 self.dispatch_low_latency(
+                     hidden_states,
+                     topk_idx,
+                     num_max_dispatch_tokens_per_rank,
+                     num_experts,
+                 )
+             )
+             self.recv_expert_count = recv_expert_count
+
+         if self.async_finish:
+             event.current_stream_wait()
+
+         self.handle = handle
+         self.topk_idx = topk_idx
+         self.topk_weights = topk_weights
+         if hidden_states.shape[0] > 0:
+             reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                 hidden_states, fp8_dtype=hidden_states.dtype
+             )
+         else:
+             reorder_topk_ids = torch.empty(
+                 (0,), device=hidden_states.device, dtype=torch.int64
+             )
+             seg_indptr = torch.zeros(
+                 (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+             )
+         return hidden_states, reorder_topk_ids, seg_indptr
+
+     def dispatch_normal(
+         self,
+         x: torch.Tensor,
+         topk_idx: torch.Tensor,
+         topk_weights: torch.Tensor,
+         num_experts: int,
+     ):
+         previous_event = Buffer.capture() if self.async_finish else None
+
+         (
+             num_tokens_per_rank,
+             num_tokens_per_rdma_rank,
+             num_tokens_per_expert,
+             is_token_in_rank,
+             previous_event,
+         ) = self.buffer_normal.get_dispatch_layout(
+             topk_idx,
+             num_experts,
+             previous_event=previous_event,
+             async_finish=self.async_finish,
+             allocate_on_comm_stream=previous_event is not None,
+         )
+
+         (
+             recv_x,
+             recv_topk_idx,
+             recv_topk_weights,
+             num_recv_tokens_per_expert_list,
+             handle,
+             event,
+         ) = self.buffer_normal.dispatch(
+             x,
+             topk_idx=topk_idx,
+             topk_weights=topk_weights,
+             num_tokens_per_rank=num_tokens_per_rank,
+             num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+             is_token_in_rank=is_token_in_rank,
+             num_tokens_per_expert=num_tokens_per_expert,
+             previous_event=previous_event,
+             async_finish=self.async_finish,
+             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
+         )
+
+         return (
+             recv_x,
+             recv_topk_idx,
+             recv_topk_weights,
+             num_recv_tokens_per_expert_list,
+             handle,
+             event,
+         )
+
+     def dispatch_low_latency(
+         self,
+         hidden_states: torch.Tensor,
+         topk_idx: torch.Tensor,
+         num_max_dispatch_tokens_per_rank: int,
+         num_experts: int,
+     ):
+         """
+         # For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
+         # Please make sure to change the DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
+         # More details: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
+
+         diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
+         index f60e933..cddaabf 100644
+         --- a/csrc/kernels/internode_ll.cu
+         +++ b/csrc/kernels/internode_ll.cu
+         @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+                        int num_topk, int num_experts, int rank, int num_ranks,
+                        void* workspace, cudaStream_t stream, int phases) {
+             constexpr int kNumMaxTopK = 9;
+         -   constexpr int kNumWarpsPerGroup = 10;
+         -   constexpr int kNumWarpGroups = 3;
+         +   constexpr int kNumWarpsPerGroup = 8;
+         +   constexpr int kNumWarpGroups = 4;
+             EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+
+             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+             const auto num_sms = cell_div(num_experts, kNumWarpGroups);
+             EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
+         -   EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+         +   // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+
+             // Workspace checks
+             auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
+         @@ -505,8 +505,8 @@ void combine(void* combined_x,
+                      int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
+                      int num_topk, int num_experts, int rank, int num_ranks,
+                      void* workspace, cudaStream_t stream, int phases) {
+         -   constexpr int kNumWarpsPerGroup = 10;
+         -   constexpr int kNumWarpGroups = 3;
+         +   constexpr int kNumWarpsPerGroup = 8;
+         +   constexpr int kNumWarpGroups = 4;
+             constexpr int kNumMaxTopk = 9;
+
+             const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+         """
+
+         recv_hidden_states, recv_expert_count, handle, event, hook = (
+             self.buffer_low_latency.low_latency_dispatch(
+                 hidden_states,
+                 topk_idx,
+                 num_max_dispatch_tokens_per_rank,
+                 num_experts,
+                 async_finish=self.async_finish,
+                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+             )
+         )
+         # hook()
+         return recv_hidden_states, recv_expert_count, handle, event, hook
+
+     def combine(
+         self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+     ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+         # Todo: enable low latency combine
+         if True:  # not forward_mode.is_decode():
+             if hidden_states.shape[0] > 0:
+                 num_tokens = self.src2dst.shape[0] // self.router_topk
+                 output = torch.empty(
+                     (num_tokens, hidden_states.shape[1]),
+                     device=hidden_states.device,
+                     dtype=hidden_states.dtype,
+                 )
+                 deepep_post_reorder_triton_kernel[(num_tokens,)](
+                     hidden_states,
+                     output,
+                     self.src2dst,
+                     self.topk_idx,
+                     self.topk_weights,
+                     self.router_topk,
+                     hidden_states.shape[1],
+                     BLOCK_SIZE=512,
+                 )
+             else:
+                 output = torch.zeros(
+                     (0, hidden_states.shape[1]),
+                     device=hidden_states.device,
+                     dtype=hidden_states.dtype,
+                 )
+             hidden_states, event = self.combine_normal(output, self.handle)
+         else:
+             hidden_states, event, hook = self.combine_low_latency(
+                 hidden_states, self.topk_idx, self.topk_weights, self.handle
+             )
+
+         if self.async_finish:
+             event.current_stream_wait()
+
+         self.handle = None
+         return hidden_states
+
+     def combine_normal(self, x: torch.Tensor, handle: Tuple):
+         previous_event = Buffer.capture() if self.async_finish else None
+
+         combined_x, _, event = self.buffer_normal.combine(
+             x,
+             handle,
+             async_finish=self.async_finish,
+             previous_event=previous_event,
+             allocate_on_comm_stream=previous_event is not None,
+         )
+         return combined_x, event
+
+     def combine_low_latency(
+         self,
+         hidden_states: torch.Tensor,
+         topk_idx: torch.Tensor,
+         topk_weights: torch.Tensor,
+         handle: Tuple,
+     ):
+         combined_hidden_states, event_overlap, hook = (
+             self.buffer_low_latency.low_latency_combine(
+                 hidden_states,
+                 topk_idx,
+                 topk_weights,
+                 handle,
+                 async_finish=self.async_finish,
+                 return_recv_hook=False,  # True for double-batch overlapping, need call hook()
+             )
+         )
+         # hook()
+         return combined_hidden_states, event_overlap, hook
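The hunk above adds the DeepEP-based token dispatcher in full. For orientation, here is a minimal sketch of how a dispatcher like this would typically be driven from an expert-parallel MoE forward pass: dispatch the routed tokens, run the local experts on the permuted activations, then combine the results back into the original token order. The wiring below (the run_experts callable, the concrete sizes, and the use of ForwardMode.EXTEND) is illustrative only and is not part of this diff.

import torch
import torch.distributed as dist

from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.model_executor.forward_batch_info import ForwardMode

# Assumes torch.distributed (and DeepEP) are already initialized.
dispatcher = DeepEPDispatcher(
    group=dist.group.WORLD,
    router_topk=8,
    num_experts=256,
    num_local_experts=256 // dist.get_world_size(),
    hidden_size=7168,
    params_dtype=torch.bfloat16,
    async_finish=True,
)


def moe_forward(hidden_states, topk_idx, topk_weights, run_experts):
    # 1) All-to-all dispatch plus local permutation into per-expert segments.
    permuted, reorder_topk_ids, seg_indptr = dispatcher.dispatch(
        hidden_states,
        topk_idx,
        topk_weights,
        num_experts=256,
        forward_mode=ForwardMode.EXTEND,
    )
    # 2) Grouped expert computation on the permuted activations
    #    (run_experts is a hypothetical stand-in for the EP-MoE expert kernels).
    expert_out = run_experts(permuted, reorder_topk_ids, seg_indptr)
    # 3) Un-permute and all-to-all combine back to the original token order.
    return dispatcher.combine(expert_out, forward_mode=ForwardMode.EXTEND)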
sglang/srt/layers/moe/fused_moe_native.py (+2 -1)
@@ -8,7 +8,6 @@ from typing import Callable, Optional
  import torch
  from torch.nn import functional as F
 
- from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
  from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -69,6 +68,8 @@ def moe_forward_native(
      activation: str = "silu",
  ) -> torch.Tensor:
 
+     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
+
      topk_weights, topk_ids = select_experts(
          hidden_states=x,
          router_logits=router_logits,
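This hunk drops the module-level GeluAndMul / SiluAndMul import and re-imports them inside moe_forward_native. Deferring an import into the function body like this is a common way to avoid a circular dependency between layer modules; the diff itself does not state the motivation, so treat the explanation as an assumption. A minimal sketch of the pattern, with hypothetical module names:

# moe_ops.py (hypothetical module, not from this diff)
def forward_native(x):
    # Imported lazily: activation.py may itself import this module at load
    # time, so a module-level import here would create an import cycle.
    from activation import SiluAndMul

    return SiluAndMul()(x)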
@@ -0,0 +1,146 @@
+ {
+     "1": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "8": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "16": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "24": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "32": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 64,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "48": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "64": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "96": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "128": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 32,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 3
+     },
+     "256": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 128,
+         "BLOCK_SIZE_K": 256,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 8,
+         "num_stages": 2
+     },
+     "512": {
+         "BLOCK_SIZE_M": 16,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 64,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1024": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "1536": {
+         "BLOCK_SIZE_M": 32,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 32,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "2048": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "3072": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     },
+     "4096": {
+         "BLOCK_SIZE_M": 64,
+         "BLOCK_SIZE_N": 64,
+         "BLOCK_SIZE_K": 128,
+         "GROUP_SIZE_M": 16,
+         "num_warps": 4,
+         "num_stages": 2
+     }
+ }
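The new JSON above follows the layout of the other fused_moe_triton tuning files: each top-level key is a token-batch size M, and its value is the Triton launch configuration (block sizes, group size, warps, stages) to use at that size. Below is a rough sketch of how such a table could be consulted at runtime; the nearest-key selection rule and the helper names are assumptions for illustration, not taken from this diff.

import json


def load_moe_config(path: str) -> dict:
    # Keys are stored as strings in the JSON file; convert them to ints for lookup.
    with open(path) as f:
        return {int(m): cfg for m, cfg in json.load(f).items()}


def pick_config(configs: dict, num_tokens: int) -> dict:
    # Assumed rule: use the entry whose batch-size key is closest to num_tokens.
    best_m = min(configs, key=lambda m: abs(m - num_tokens))
    return configs[best_m]


# Hypothetical usage:
#   cfg = pick_config(load_moe_config("E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json"), 384)
#   kernel[grid](..., BLOCK_SIZE_M=cfg["BLOCK_SIZE_M"], num_warps=cfg["num_warps"], num_stages=cfg["num_stages"])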