sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py (new file)

```diff
@@ -0,0 +1,416 @@
+try:
+    from deep_ep import Buffer
+
+    use_deepep = True
+except ImportError:
+    use_deepep = False
+
+from typing import Optional, Tuple
+
+import torch
+import torch.distributed as dist
+
+from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_permute_triton_kernel,
+    deepep_post_reorder_triton_kernel,
+    deepep_run_moe_deep_preprocess,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
+
+_buffer_normal = None
+_buffer_low_latency = None
+
+
+def get_buffer_normal(group: dist.ProcessGroup, hidden_bytes: int):
+    """
+    Copy from DeepEP example usage in model inference prefilling.
+    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-model-training-or-inference-prefilling
+    """
+
+    global _buffer_normal
+
+    num_nvl_bytes, num_rdma_bytes = 0, 0
+    for config in (
+        Buffer.get_dispatch_config(group.size()),
+        Buffer.get_combine_config(group.size()),
+    ):
+        num_nvl_bytes = max(
+            config.get_nvl_buffer_size_hint(hidden_bytes, group.size()), num_nvl_bytes
+        )
+        num_rdma_bytes = max(
+            config.get_rdma_buffer_size_hint(hidden_bytes, group.size()), num_rdma_bytes
+        )
+
+    if (
+        _buffer_normal is None
+        or _buffer_normal.group != group
+        or _buffer_normal.num_nvl_bytes < num_nvl_bytes
+        or _buffer_normal.num_rdma_bytes < num_rdma_bytes
+    ):
+        _buffer_normal = Buffer(group, num_nvl_bytes, num_rdma_bytes)
+    return _buffer_normal
+
+
+def get_buffer_low_latency(
+    group: dist.ProcessGroup,
+    num_max_dispatch_tokens_per_rank: int,
+    hidden: int,
+    num_experts: int,
+):
+    """
+    Copy from DeepEP example usage in model inference decoding.
+    https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+    """
+
+    global _buffer_low_latency
+    num_rdma_bytes = Buffer.get_low_latency_rdma_size_hint(
+        num_max_dispatch_tokens_per_rank, hidden, group.size(), num_experts
+    )
+
+    if (
+        _buffer_low_latency is None
+        or _buffer_low_latency.group != group
+        or not _buffer_low_latency.low_latency_mode
+        or _buffer_low_latency.num_rdma_bytes < num_rdma_bytes
+    ):
+        assert num_experts % group.size() == 0
+        _buffer_low_latency = Buffer(
+            group,
+            0,
+            num_rdma_bytes,
+            low_latency_mode=True,
+            num_qps_per_rank=num_experts // group.size(),
+        )
+    return _buffer_low_latency
+
+
+class DeepEPDispatcher:
+    """
+    Copy from Megatron-Core token_dispatcher MoEFlexTokenDispatcher
+    https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/transformer/moe/token_dispatcher.py
+    """
+
+    def __init__(
+        self,
+        group: torch.distributed.ProcessGroup,
+        router_topk: int,
+        permute_fusion: bool = False,
+        capacity_factor: float = None,
+        num_experts: int = None,
+        num_local_experts: int = None,
+        hidden_size: int = None,
+        params_dtype: torch.dtype = None,
+        async_finish: bool = False,
+    ):
+        self.group = group
+        self.router_topk = router_topk
+        self.capacity_factor = capacity_factor
+        self.permute_fusion = permute_fusion
+        self.num_experts = num_experts
+        self.num_local_experts = num_local_experts
+        self.hidden_size = hidden_size
+        self.recv_expert_count = None
+        self.params_dtype = params_dtype
+        self.params_bytes = 2
+        # Metadata
+        self.token_indices = None
+        self.token_probs = None
+        # Handle used for combine operation
+        self.handle = None
+        self.async_finish = async_finish
+
+        # `num_max_dispatch_tokens_per_rank` (the actual batch size in the decoding engine) should be less than 256
+        # https://github.com/deepseek-ai/DeepEP?tab=readme-ov-file#example-use-in-inference-decoding
+        self.num_max_dispatch_tokens_per_rank = 128
+
+        if not use_deepep:
+            raise ImportError(
+                "DeepEP is not installed. Please install the DeepEP package from "
+                "https://github.com/deepseek-ai/deepep."
+            )
+        self.buffer_normal = get_buffer_normal(
+            self.group, self.hidden_size * self.params_bytes
+        )
+        self.buffer_low_latency = None
+        # TODO: enable low-latency dispatch
+        """
+        self.buffer_low_latency = get_buffer_low_latency(
+            self.group,
+            self.num_max_dispatch_tokens_per_rank,
+            self.hidden_size * self.params_bytes,
+            self.num_experts,
+        )
+        """
+
+    def deepep_permute(
+        self,
+        hidden_states,
+        fp8_dtype=None,
+        use_fp8_w8a8=False,
+        use_block_quant=False,
+    ):
+        reorder_topk_ids, src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
+            self.topk_idx, self.num_experts
+        )
+        num_total_tokens = reorder_topk_ids.numel()
+        gateup_input = torch.empty(
+            (int(num_total_tokens), hidden_states.shape[1]),
+            device=hidden_states.device,
+            dtype=(
+                fp8_dtype
+                if (use_fp8_w8a8 and not use_block_quant)
+                else hidden_states.dtype
+            ),
+        )
+        # PreReorder
+        deepep_permute_triton_kernel[(hidden_states.shape[0],)](
+            hidden_states,
+            gateup_input,
+            src2dst,
+            self.topk_idx,
+            None,
+            self.router_topk,
+            hidden_states.shape[1],
+            BLOCK_SIZE=512,
+        )
+        self.src2dst = src2dst
+        return reorder_topk_ids, seg_indptr, gateup_input
+
+    def dispatch(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+        forward_mode: ForwardMode,
+        num_max_dispatch_tokens_per_rank: int = 128,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        topk_idx = topk_idx.to(torch.int64)
+        # TODO: enable low-latency dispatch
+        if True:  # not forward_mode.is_decode():
+            (
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                num_recv_tokens_per_expert_list,
+                handle,
+                event,
+            ) = self.dispatch_normal(hidden_states, topk_idx, topk_weights, num_experts)
+            self.tokens_per_expert = torch.tensor(
+                num_recv_tokens_per_expert_list,
+                device=hidden_states.device,
+                dtype=torch.int64,
+            )
+        else:
+            hidden_states, recv_expert_count, handle, event, hook = (
+                self.dispatch_low_latency(
+                    hidden_states,
+                    topk_idx,
+                    num_max_dispatch_tokens_per_rank,
+                    num_experts,
+                )
+            )
+            self.recv_expert_count = recv_expert_count
+
+        if self.async_finish:
+            event.current_stream_wait()
+
+        self.handle = handle
+        self.topk_idx = topk_idx
+        self.topk_weights = topk_weights
+        if hidden_states.shape[0] > 0:
+            reorder_topk_ids, seg_indptr, hidden_states = self.deepep_permute(
+                hidden_states, fp8_dtype=hidden_states.dtype
+            )
+        else:
+            reorder_topk_ids = torch.empty(
+                (0,), device=hidden_states.device, dtype=torch.int64
+            )
+            seg_indptr = torch.zeros(
+                (num_experts + 1,), device=hidden_states.device, dtype=torch.int64
+            )
+        return hidden_states, reorder_topk_ids, seg_indptr
+
+    def dispatch_normal(
+        self,
+        x: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        num_experts: int,
+    ):
+        previous_event = Buffer.capture() if self.async_finish else None
+
+        (
+            num_tokens_per_rank,
+            num_tokens_per_rdma_rank,
+            num_tokens_per_expert,
+            is_token_in_rank,
+            previous_event,
+        ) = self.buffer_normal.get_dispatch_layout(
+            topk_idx,
+            num_experts,
+            previous_event=previous_event,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+
+        (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        ) = self.buffer_normal.dispatch(
+            x,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            is_token_in_rank=is_token_in_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
+            previous_event=previous_event,
+            async_finish=self.async_finish,
+            allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
+        )
+
+        return (
+            recv_x,
+            recv_topk_idx,
+            recv_topk_weights,
+            num_recv_tokens_per_expert_list,
+            handle,
+            event,
+        )
+
+    def dispatch_low_latency(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        num_max_dispatch_tokens_per_rank: int,
+        num_experts: int,
+    ):
+        """
+        # For H20, there will be a CUDA error: DeepEP/csrc/kernels/internode_ll.cu:337 'too many blocks in cooperative launch'
+        # Please make sure to change the DeepEP code in internode_ll.cu dispatch / combine first and then reinstall!
+        # More details: https://github.com/deepseek-ai/DeepEP/issues/15#issuecomment-2709715782
+
+        diff --git a/csrc/kernels/internode_ll.cu b/csrc/kernels/internode_ll.cu
+        index f60e933..cddaabf 100644
+        --- a/csrc/kernels/internode_ll.cu
+        +++ b/csrc/kernels/internode_ll.cu
+        @@ -307,14 +307,14 @@ void dispatch(void* packed_recv_x, float* packed_recv_x_scales,
+                      int num_topk, int num_experts, int rank, int num_ranks,
+                      void* workspace, cudaStream_t stream, int phases) {
+            constexpr int kNumMaxTopK = 9;
+        -   constexpr int kNumWarpsPerGroup = 10;
+        -   constexpr int kNumWarpGroups = 3;
+        +   constexpr int kNumWarpsPerGroup = 8;
+        +   constexpr int kNumWarpGroups = 4;
+            EP_STATIC_ASSERT(kNumMaxTopK + 1 <= kNumWarpGroups * kNumWarpsPerGroup, "Too many top-k selections");
+
+            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+            const auto num_sms = cell_div(num_experts, kNumWarpGroups);
+            EP_HOST_ASSERT(num_topk <= kNumMaxTopK);
+        -   EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+        +   // EP_HOST_ASSERT(cell_div(static_cast<int>(hidden * 2 / sizeof(int4)), 32 * (num_warps - 1)) <= 2);
+
+            // Workspace checks
+            auto atomic_counter_per_expert = reinterpret_cast<int*>(workspace);
+        @@ -505,8 +505,8 @@ void combine(void* combined_x,
+                     int num_combined_tokens, int hidden, int num_max_dispatch_tokens_per_rank,
+                     int num_topk, int num_experts, int rank, int num_ranks,
+                     void* workspace, cudaStream_t stream, int phases) {
+        -   constexpr int kNumWarpsPerGroup = 10;
+        -   constexpr int kNumWarpGroups = 3;
+        +   constexpr int kNumWarpsPerGroup = 8;
+        +   constexpr int kNumWarpGroups = 4;
+            constexpr int kNumMaxTopk = 9;
+
+            const auto num_warps = kNumWarpGroups * kNumWarpsPerGroup;
+        """
+
+        recv_hidden_states, recv_expert_count, handle, event, hook = (
+            self.buffer_low_latency.low_latency_dispatch(
+                hidden_states,
+                topk_idx,
+                num_max_dispatch_tokens_per_rank,
+                num_experts,
+                async_finish=self.async_finish,
+                return_recv_hook=False,  # True for double-batch overlapping; need to call hook()
+            )
+        )
+        # hook()
+        return recv_hidden_states, recv_expert_count, handle, event, hook
+
+    def combine(
+        self, hidden_states: torch.Tensor, forward_mode: ForwardMode
+    ) -> Tuple[torch.Tensor, Optional[torch.Tensor]]:
+        # TODO: enable low-latency combine
+        if True:  # not forward_mode.is_decode():
+            if hidden_states.shape[0] > 0:
+                num_tokens = self.src2dst.shape[0] // self.router_topk
+                output = torch.empty(
+                    (num_tokens, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+                deepep_post_reorder_triton_kernel[(num_tokens,)](
+                    hidden_states,
+                    output,
+                    self.src2dst,
+                    self.topk_idx,
+                    self.topk_weights,
+                    self.router_topk,
+                    hidden_states.shape[1],
+                    BLOCK_SIZE=512,
+                )
+            else:
+                output = torch.zeros(
+                    (0, hidden_states.shape[1]),
+                    device=hidden_states.device,
+                    dtype=hidden_states.dtype,
+                )
+            hidden_states, event = self.combine_normal(output, self.handle)
+        else:
+            hidden_states, event, hook = self.combine_low_latency(
+                hidden_states, self.topk_idx, self.topk_weights, self.handle
+            )
+
+        if self.async_finish:
+            event.current_stream_wait()
+
+        self.handle = None
+        return hidden_states
+
+    def combine_normal(self, x: torch.Tensor, handle: Tuple):
+        previous_event = Buffer.capture() if self.async_finish else None
+
+        combined_x, _, event = self.buffer_normal.combine(
+            x,
+            handle,
+            async_finish=self.async_finish,
+            previous_event=previous_event,
+            allocate_on_comm_stream=previous_event is not None,
+        )
+        return combined_x, event
+
+    def combine_low_latency(
+        self,
+        hidden_states: torch.Tensor,
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        handle: Tuple,
+    ):
+        combined_hidden_states, event_overlap, hook = (
+            self.buffer_low_latency.low_latency_combine(
+                hidden_states,
+                topk_idx,
+                topk_weights,
+                handle,
+                async_finish=self.async_finish,
+                return_recv_hook=False,  # True for double-batch overlapping; need to call hook()
+            )
+        )
+        # hook()
+        return combined_hidden_states, event_overlap, hook
```
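For orientation, here is a minimal sketch of how a MoE layer would drive this dispatcher around its expert computation. It assumes DeepEP is installed and `torch.distributed` is already initialized; the shapes are toy values rather than sglang defaults, and the expert computation itself is elided.

```python
import torch
import torch.distributed as dist

from sglang.srt.layers.moe.ep_moe.token_dispatcher import DeepEPDispatcher
from sglang.srt.model_executor.forward_batch_info import ForwardMode

# Toy shapes; in practice these come from the model config and the router.
num_tokens, hidden_size, num_experts, topk = 4, 7168, 256, 8
hidden_states = torch.randn(num_tokens, hidden_size, dtype=torch.bfloat16, device="cuda")
topk_weights = torch.rand(num_tokens, topk, device="cuda")
topk_idx = torch.randint(0, num_experts, (num_tokens, topk), device="cuda")

dispatcher = DeepEPDispatcher(
    group=dist.group.WORLD,  # requires dist.init_process_group() beforehand
    router_topk=topk,
    num_experts=num_experts,
    num_local_experts=num_experts // dist.get_world_size(),
    hidden_size=hidden_size,
    params_dtype=torch.bfloat16,
    async_finish=True,
)

# dispatch() routes each token to the ranks owning its top-k experts and
# returns the permuted activations plus the per-expert segment layout.
hidden_states, reorder_topk_ids, seg_indptr = dispatcher.dispatch(
    hidden_states, topk_idx, topk_weights, num_experts, ForwardMode.EXTEND
)

# ... run the local experts over hidden_states, segmented by seg_indptr ...

# combine() reverses the permutation and reduces expert outputs per token.
hidden_states = dispatcher.combine(hidden_states, ForwardMode.EXTEND)
```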
sglang/srt/layers/moe/fused_moe_native.py

```diff
@@ -8,7 +8,6 @@ from typing import Callable, Optional
 import torch
 from torch.nn import functional as F
 
-from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
 from sglang.srt.layers.moe.topk import select_experts
 
 
@@ -69,6 +68,8 @@ def moe_forward_native(
     activation: str = "silu",
 ) -> torch.Tensor:
 
+    from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
+
     topk_weights, topk_ids = select_experts(
         hidden_states=x,
         router_logits=router_logits,
```
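The two hunks above move the `GeluAndMul`/`SiluAndMul` import from module scope into the body of `moe_forward_native`. The diff does not state the motivation, but deferring an import to its first call site is the standard way to break a module-level circular import; a minimal sketch of the pattern, with `module_b` and `activation_fn` as placeholders rather than sglang names:

```python
# module_a.py
# A top-level `from module_b import activation_fn` would fail at load time
# if module_b (directly or transitively) imports module_a.

def forward(x):
    # Deferred import: resolved on the first call, after both modules
    # have finished loading, so the circular chain never closes.
    from module_b import activation_fn

    return activation_fn(x)
```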
fused_moe_triton tuning config (one of the new +146-line JSON files listed above)

```diff
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    }
+}
```
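These tuning tables are keyed by the number of tokens the fused-MoE kernel processes (M) and store the Triton launch parameters tuned for that size. At runtime the kernel picks the entry whose key is nearest to the actual M. The sketch below mirrors the nearest-key lookup used by the upstream fused_moe Triton code; the helper names are hypothetical, not sglang APIs.

```python
import json

def load_moe_config(path: str) -> dict[int, dict]:
    # JSON keys are strings ("1", "2", ...); convert to ints for arithmetic.
    with open(path) as f:
        return {int(k): v for k, v in json.load(f).items()}

def pick_config(configs: dict[int, dict], m: int) -> dict:
    # Choose the tuned entry whose token-count key is closest to M.
    return configs[min(configs, key=lambda k: abs(k - m))]

# e.g. with the table above, m=50 selects the "48" entry
# (BLOCK_SIZE_M=16, BLOCK_SIZE_N=32, BLOCK_SIZE_K=128, ...).
```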