sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
ADDED
@@ -0,0 +1,342 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.moe.topk import fused_topk
+
+
+@triton.jit
+def fused_moe_router_kernel(
+    input_ptr,  # input (bs, hidden_dim)
+    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    # moe_router_weight is k major
+    expert_offsets = tl.arange(0, num_experts)[:, None]
+    router_mask = mask[None, :]
+    w_router = tl.load(
+        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+        mask=router_mask,
+        other=0.0,
+    )
+
+    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+    # todo: tl.dot?
+    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+    # logit softcap
+    logits_scaled = logits / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    top = exped - 1
+    bottom = exped + 1
+    logits_softcapped = top / bottom * moe_softcapping
+
+    # topk
+    # assert 1 <= topk <= num_experts
+
+    # 5.38 us
+
+    top1 = tl.argmax(logits_softcapped, axis=0)
+    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+    top1_v = tl.max(logits_softcapped, axis=0)
+    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+    tl.store(
+        topk_weights_ptr + pid * topk + 0,
+        invsumexp,
+    )  # 5.73 us
+
+    if topk >= 2:
+        top2 = tl.argmax(
+            tl.where(
+                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+            ),
+            axis=0,
+        )
+        tl.store(topk_ids_ptr + pid * topk + 1, top2)
+        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+        tl.store(
+            topk_weights_ptr + pid * topk + 1,
+            tl.exp(top2_v - top1_v) * invsumexp,
+        )  # 5.95us
+
+    # probably slow
+    if topk > 2:
+        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+        )
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+        )
+        for i in range(2, topk):
+            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+            topk_mask = tl.where(
+                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+            )
+            tl.store(topk_ids_ptr + pid * topk + i, topi)
+            topi_v = tl.sum(
+                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+            )
+            tl.store(
+                topk_weights_ptr + pid * topk + i,
+                tl.exp(topi_v - top1_v) * invsumexp,
+            )
+    # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+def fused_moe_router_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = lambda meta: (bs,)
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_moe_router_kernel[grid](
+        x,
+        router_weight,
+        topk_weights,
+        topk_ids,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        hidden_dim=hidden_dim,
+        **config,
+    )
+
+    return topk_weights, topk_ids
+
+
+@triton.jit
+def fused_moe_router_large_bs_kernel(
+    a_ptr,  # input (bs, hidden_dim)
+    b_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    bs,
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,  # only support topk == 1
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    stride_am: tl.constexpr,
+    stride_bn: tl.constexpr,
+):
+
+    # 1. get block id
+    pid = tl.program_id(axis=0)
+
+    # 2. create pointers for the first block of A and B
+    # 2.1. setup a_ptrs with offsets in m and k
+    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+    bs_mask = offs_m < bs
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+    # 2.2. setup b_ptrs with offsets in k and n.
+    # Note: b matrix is k-major.
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+    expert_mask = offs_n < num_experts
+    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    # 3.1. iterate in K dimension
+    # 3.2. transpose tile B
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+        a = tl.load(
+            a_ptrs,
+            mask=bs_mask,
+            other=0.0,
+        ).to(tl.float32)
+        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+        acc += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    # 4. logit softcap
+    logits_scaled = acc / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+    # 5. top1
+    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    top1_v = tl.max(
+        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+    )
+    invsumexp = 1.0 / tl.sum(
+        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    )
+
+    # 6. store to output
+    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    topk_mask = offs_topk < bs
+    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    tl.store(
+        topk_weights_ptr + offs_topk,
+        invsumexp,
+        mask=topk_mask,
+    )
+
+
+def fused_moe_router_large_bs_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+    BLOCK_SIZE_M: int,
+    BLOCK_SIZE_N: int,
+    BLOCK_SIZE_K: int,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    assert num_experts <= BLOCK_SIZE_N
+    assert hidden_dim % BLOCK_SIZE_K == 0
+    assert topk == 1
+
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+    fused_moe_router_large_bs_kernel[grid](
+        a_ptr=x,
+        b_ptr=router_weight,
+        topk_weights_ptr=topk_weights,
+        topk_ids_ptr=topk_ids,
+        bs=bs,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        K=hidden_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        stride_am=hidden_dim,
+        stride_bn=hidden_dim,
+    )
+
+    return topk_weights, topk_ids
+
+
+def fused_moe_router_shim(
+    moe_softcapping,
+    hidden_states,
+    gating_output,
+    topk,
+    renormalize,
+):
+    assert not renormalize
+    assert (
+        len(hidden_states.shape) == 2
+        and hidden_states.shape[1] == gating_output.shape[1]
+    )
+    bs, hidden_dim = hidden_states.shape
+    num_experts = gating_output.shape[0]
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 16
+    BLOCK_SIZE_K = 256
+    if (
+        bs >= 512
+        and topk == 1
+        and num_experts <= BLOCK_SIZE_N
+        and hidden_dim % BLOCK_SIZE_K == 0
+    ):
+        return fused_moe_router_large_bs_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+        )
+    else:
+        return fused_moe_router_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+        )
+
+
+class FusedMoeRouter:
+    def __init__(self, router_linear, topk, moe_softcapping) -> None:
+        self.router_linear = router_linear
+        self.topk = topk
+        self.moe_softcapping = moe_softcapping
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_vllm(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_moe_router_shim(
+            moe_softcapping=self.moe_softcapping,
+            hidden_states=x,
+            gating_output=self.router_linear.weight,
+            topk=self.topk,
+            renormalize=False,
+        )
+
+    def forward_vllm(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # g, _ = self.router_linear.forward(x)
+        g = x.float() @ self.router_linear.weight.T.float()
+
+        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+        return fused_topk(x, g, self.topk, False)
sglang/srt/layers/moe/topk.py
CHANGED
@@ -17,7 +17,14 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+
+expert_distribution_recorder = ExpertDistributionRecorder()
 
 
 def fused_topk_native(
@@ -47,7 +54,10 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    from vllm import _custom_ops as vllm_ops
+    if _is_cuda or _is_hip:
+        from sgl_kernel import topk_softmax
+    else:
+        from vllm import _custom_ops as vllm_ops
 
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -61,12 +71,20 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    vllm_ops.topk_softmax(
-        topk_weights,
-        topk_ids,
-        token_expert_indicies,
-        gating_output.float(),
-    )
+    if _is_cuda or _is_hip:
+        topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),
+        )
+    else:
+        vllm_ops.topk_softmax(
+            topk_weights,
+            topk_ids,
+            token_expert_indicies,
+            gating_output.float(),
+        )
     del token_expert_indicies
 
     if renormalize:
@@ -75,6 +93,7 @@ def fused_topk(
     return topk_weights, topk_ids
 
 
+# This is used by the Deepseek V2/V3/R1 series models
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def grouped_topk(
     hidden_states: torch.Tensor,
@@ -83,17 +102,10 @@ def grouped_topk(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    scoring_func: str = "softmax",
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
-    if scoring_func == "softmax":
-        scores = torch.softmax(gating_output, dim=-1)
-    elif scoring_func == "sigmoid":
-        scores = gating_output.sigmoid()
-    else:
-        raise ValueError(f"Scoring function '{scoring_func}' is not supported.")
-
+    scores = torch.softmax(gating_output, dim=-1)
     num_token = scores.shape[0]
     group_scores = (
         scores.view(num_token, num_expert_group, -1).max(dim=-1).values
@@ -117,7 +129,6 @@ def grouped_topk(
     return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
 
 
-# DeepSeek V2/V3/R1 uses biased_grouped_top
 @torch.compile(dynamic=True, backend=get_compiler_backend())
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
@@ -172,7 +183,7 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
    torch_native: bool = False,
 ):
-    #
+    # DeekSeekv2 uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
         assert num_expert_group is not None
@@ -217,4 +228,6 @@ def select_experts(
         renormalize=renormalize,
     )
 
+    expert_distribution_recorder.record_new_token(topk_ids)
+
     return topk_weights, topk_ids
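
The fused_topk change above is pure dispatch: CUDA and HIP builds now take topk_softmax from sgl_kernel, with vLLM's _custom_ops kept as the fallback, and the outputs are the same either way. As a reference for what either kernel writes into topk_weights and topk_ids, here is a plain-PyTorch sketch (an illustrative reference, not code from the wheel; renormalization is applied afterwards by fused_topk itself):

import torch

def topk_softmax_reference(gating_output: torch.Tensor, topk: int):
    # Softmax over the expert dimension in float32, then keep the k
    # largest probabilities and their expert indices per token.
    probs = torch.softmax(gating_output.float(), dim=-1)
    topk_weights, topk_ids = probs.topk(topk, dim=-1)
    return topk_weights, topk_ids.to(torch.int32)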
sglang/srt/layers/parameter.py
CHANGED
@@ -105,6 +105,7 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         shard_offset = kwargs.get("shard_offset")
         shard_size = kwargs.get("shard_size")
+        tp_rank = kwargs.get("tp_rank")
         use_presharded_weights = kwargs.get("use_presharded_weights")
         if (
             isinstance(self, (PackedColumnParameter, PackedvLLMParameter))
@@ -116,7 +117,6 @@ class _ColumnvLLMParameter(BasevLLMParameter):
 
         param_data = self.data
 
-        tp_rank = get_tensor_model_parallel_rank()
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
         if not use_presharded_weights:
             loaded_weight = loaded_weight.narrow(