sglang 0.4.3.post4__py3-none-any.whl → 0.4.4.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/function_call_parser.py +33 -2
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +1 -3
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/dp_attention.py +30 -2
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/logits_processor.py +1 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +74 -8
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +32 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +213 -118
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +176 -683
- sglang/srt/managers/scheduler_output_processor_mixin.py +614 -0
- sglang/srt/managers/tokenizer_manager.py +6 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +71 -34
- sglang/srt/mem_cache/memory_pool.py +81 -17
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/model_executor/cuda_graph_runner.py +68 -20
- sglang/srt/model_executor/forward_batch_info.py +23 -10
- sglang/srt/model_executor/model_runner.py +63 -63
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +200 -191
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +59 -35
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +24 -16
- sglang/srt/speculative/eagle_worker.py +75 -39
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/METADATA +9 -10
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/RECORD +131 -84
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post4.dist-info → sglang-0.4.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
ADDED
@@ -0,0 +1,342 @@
+from typing import Tuple
+
+import torch
+import triton
+import triton.language as tl
+
+from sglang.srt.layers.moe.topk import fused_topk
+
+
+@triton.jit
+def fused_moe_router_kernel(
+    input_ptr,  # input (bs, hidden_dim)
+    moe_router_weight_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    hidden_dim: tl.constexpr,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+
+    offsets = tl.arange(0, BLOCK_SIZE)
+    mask = offsets < hidden_dim
+
+    # moe_router_weight is k major
+    expert_offsets = tl.arange(0, num_experts)[:, None]
+    router_mask = mask[None, :]
+    w_router = tl.load(
+        moe_router_weight_ptr + expert_offsets * hidden_dim + offsets[None, :],
+        mask=router_mask,
+        other=0.0,
+    )
+
+    x = tl.load(input_ptr + pid * hidden_dim + offsets, mask=mask, other=0.0)
+
+    # todo: tl.dot?
+    logits = tl.sum((w_router.to(tl.float32) * x[None, :].to(tl.float32)), axis=-1)
+
+    # logit softcap
+    logits_scaled = logits / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    top = exped - 1
+    bottom = exped + 1
+    logits_softcapped = top / bottom * moe_softcapping
+
+    # topk
+    # assert 1 <= topk <= num_experts
+
+    # 5.38 us
+
+    top1 = tl.argmax(logits_softcapped, axis=0)
+    tl.store(topk_ids_ptr + pid * topk + 0, top1)  # 5.63 us
+
+    top1_v = tl.max(logits_softcapped, axis=0)
+    invsumexp = 1.0 / tl.sum(tl.exp(logits_softcapped - top1_v), axis=0)
+
+    tl.store(
+        topk_weights_ptr + pid * topk + 0,
+        invsumexp,
+    )  # 5.73 us
+
+    if topk >= 2:
+        top2 = tl.argmax(
+            tl.where(
+                tl.arange(0, num_experts) != top1, logits_softcapped, float("-inf")
+            ),
+            axis=0,
+        )
+        tl.store(topk_ids_ptr + pid * topk + 1, top2)
+        top2_v = tl.sum(logits_softcapped * (tl.arange(0, num_experts) == top2), axis=0)
+        tl.store(
+            topk_weights_ptr + pid * topk + 1,
+            tl.exp(top2_v - top1_v) * invsumexp,
+        )  # 5.95us
+
+    # probably slow
+    if topk > 2:
+        topk_mask = tl.full(logits_softcapped.shape, 1.0, dtype=logits_softcapped.dtype)
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top1, topk_mask, float("-inf")
+        )
+        topk_mask = tl.where(
+            tl.arange(0, num_experts) != top2, topk_mask, float("-inf")
+        )
+        for i in range(2, topk):
+            topi = tl.argmax(logits_softcapped + topk_mask, axis=0)
+            topk_mask = tl.where(
+                tl.arange(0, num_experts) != topi, topk_mask, float("-inf")
+            )
+            tl.store(topk_ids_ptr + pid * topk + i, topi)
+            topi_v = tl.sum(
+                logits_softcapped * (tl.arange(0, num_experts) == topi), axis=0
+            )
+            tl.store(
+                topk_weights_ptr + pid * topk + i,
+                tl.exp(topi_v - top1_v) * invsumexp,
+            )
+    # assert not moe_renormalize, "moe weight renormalization not implemented"
+
+
+def fused_moe_router_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = lambda meta: (bs,)
+    config = {
+        "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
+        "num_warps": max(
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), 32), 4
+        ),
+    }
+
+    fused_moe_router_kernel[grid](
+        x,
+        router_weight,
+        topk_weights,
+        topk_ids,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        hidden_dim=hidden_dim,
+        **config,
+    )
+
+    return topk_weights, topk_ids
+
+
+@triton.jit
+def fused_moe_router_large_bs_kernel(
+    a_ptr,  # input (bs, hidden_dim)
+    b_ptr,  # input (num_experts, hidden_dim)
+    topk_weights_ptr,  # output (bs, topk)
+    topk_ids_ptr,  # output (bs, topk)
+    bs,
+    num_experts: tl.constexpr,
+    topk: tl.constexpr,  # only support topk == 1
+    moe_softcapping: tl.constexpr,
+    moe_renormalize: tl.constexpr,  # not supported
+    K: tl.constexpr,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    stride_am: tl.constexpr,
+    stride_bn: tl.constexpr,
+):
+
+    # 1. get block id
+    pid = tl.program_id(axis=0)
+
+    # 2. create pointers for the first block of A and B
+    # 2.1. setup a_ptrs with offsets in m and k
+    offs_m = pid * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)[:, None]
+    bs_mask = offs_m < bs
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    a_ptrs = a_ptr + (offs_m * stride_am + offs_k)
+
+    # 2.2. setup b_ptrs with offsets in k and n.
+    # Note: b matrix is k-major.
+    offs_k = tl.arange(0, BLOCK_SIZE_K)[None, :]
+    offs_n = tl.arange(0, BLOCK_SIZE_N)[:, None]
+    expert_mask = offs_n < num_experts
+    b_ptrs = b_ptr + (offs_n * stride_bn + offs_k)
+
+    # 3. Create an accumulator of float32 of size [BLOCK_SIZE_M, BLOCK_SIZE_N]
+    # 3.1. iterate in K dimension
+    # 3.2. transpose tile B
+    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, K // BLOCK_SIZE_K):  # hidden_dim % BLOCK_SIZE_K == 0
+        a = tl.load(
+            a_ptrs,
+            mask=bs_mask,
+            other=0.0,
+        ).to(tl.float32)
+        b = tl.load(b_ptrs, mask=expert_mask, other=0.0).to(tl.float32).T
+        acc += tl.dot(a, b)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K
+        b_ptrs += BLOCK_SIZE_K
+
+    # 4. logit softcap
+    logits_scaled = acc / moe_softcapping
+    exped = tl.exp(2 * logits_scaled)
+    logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
+
+    # 5. top1
+    cond = tl.arange(0, BLOCK_SIZE_N)[None, :] < num_experts
+    top1 = tl.argmax(tl.where(cond, logits_softcapped, float("-inf")), axis=1)
+    top1_v = tl.max(
+        tl.where(cond, logits_softcapped, float("-inf")), axis=1, keep_dims=True
+    )
+    invsumexp = 1.0 / tl.sum(
+        tl.where(cond, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
+    )
+
+    # 6. store to output
+    offs_topk = pid * topk * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    topk_mask = offs_topk < bs
+    tl.store(topk_ids_ptr + offs_topk, top1, mask=topk_mask)
+    tl.store(
+        topk_weights_ptr + offs_topk,
+        invsumexp,
+        mask=topk_mask,
+    )
+
+
+def fused_moe_router_large_bs_impl(
+    x: torch.Tensor,
+    router_weight: torch.Tensor,
+    topk: int,
+    moe_softcapping: float,
+    BLOCK_SIZE_M: int,
+    BLOCK_SIZE_N: int,
+    BLOCK_SIZE_K: int,
+):
+    assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
+    bs, hidden_dim = x.shape
+    num_experts = router_weight.shape[0]
+
+    assert num_experts <= BLOCK_SIZE_N
+    assert hidden_dim % BLOCK_SIZE_K == 0
+    assert topk == 1
+
+    topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
+    topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+
+    grid = (triton.cdiv(bs, BLOCK_SIZE_M) * triton.cdiv(num_experts, BLOCK_SIZE_N),)
+
+    fused_moe_router_large_bs_kernel[grid](
+        a_ptr=x,
+        b_ptr=router_weight,
+        topk_weights_ptr=topk_weights,
+        topk_ids_ptr=topk_ids,
+        bs=bs,
+        num_experts=num_experts,
+        topk=topk,
+        moe_softcapping=moe_softcapping,
+        moe_renormalize=False,
+        K=hidden_dim,
+        BLOCK_SIZE_M=BLOCK_SIZE_M,
+        BLOCK_SIZE_N=BLOCK_SIZE_N,
+        BLOCK_SIZE_K=BLOCK_SIZE_K,
+        stride_am=hidden_dim,
+        stride_bn=hidden_dim,
+    )
+
+    return topk_weights, topk_ids
+
+
+def fused_moe_router_shim(
+    moe_softcapping,
+    hidden_states,
+    gating_output,
+    topk,
+    renormalize,
+):
+    assert not renormalize
+    assert (
+        len(hidden_states.shape) == 2
+        and hidden_states.shape[1] == gating_output.shape[1]
+    )
+    bs, hidden_dim = hidden_states.shape
+    num_experts = gating_output.shape[0]
+    BLOCK_SIZE_M = 32
+    BLOCK_SIZE_N = 16
+    BLOCK_SIZE_K = 256
+    if (
+        bs >= 512
+        and topk == 1
+        and num_experts <= BLOCK_SIZE_N
+        and hidden_dim % BLOCK_SIZE_K == 0
+    ):
+        return fused_moe_router_large_bs_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+            BLOCK_SIZE_M=BLOCK_SIZE_M,
+            BLOCK_SIZE_N=BLOCK_SIZE_N,
+            BLOCK_SIZE_K=BLOCK_SIZE_K,
+        )
+    else:
+        return fused_moe_router_impl(
+            x=hidden_states,
+            router_weight=gating_output,
+            topk=topk,
+            moe_softcapping=moe_softcapping,
+        )
+
+
+class FusedMoeRouter:
+    def __init__(self, router_linear, topk, moe_softcapping) -> None:
+        self.router_linear = router_linear
+        self.topk = topk
+        self.moe_softcapping = moe_softcapping
+
+    def __call__(self, *args, **kwargs):
+        return self.forward(*args, **kwargs)
+
+    def forward(
+        self, x: torch.Tensor, residual: torch.Tensor
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        if x.is_cuda:
+            return self.forward_cuda(x, residual)
+        else:
+            return self.forward_vllm(x, residual)
+
+    def forward_cuda(
+        self, x: torch.Tensor, autotune=False
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        return fused_moe_router_shim(
+            moe_softcapping=self.moe_softcapping,
+            hidden_states=x,
+            gating_output=self.router_linear.weight,
+            topk=self.topk,
+            renormalize=False,
+        )
+
+    def forward_vllm(
+        self,
+        x: torch.Tensor,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # g, _ = self.router_linear.forward(x)
+        g = x.float() @ self.router_linear.weight.T.float()
+
+        g = torch.tanh(g.float() / self.moe_softcapping) * self.moe_softcapping
+
+        return fused_topk(x, g, self.topk, False)
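The new router module fuses the MoE gating matmul, tanh logit softcapping, and top-k selection into a single Triton launch, with a separate tiled kernel for large batches (topk == 1 only). A minimal usage sketch follows; the shapes, dtype, and softcap value are illustrative assumptions, not values taken from this diff. Note that `fused_moe_router_shim` expects the gate's weight matrix itself as `gating_output`, not precomputed logits (see `forward_cuda` above).

import torch

from sglang.srt.layers.moe.router import fused_moe_router_shim

# Illustrative shapes: 8 tokens, hidden_dim 1024, 8 experts (assumed values).
hidden_states = torch.randn(8, 1024, device="cuda", dtype=torch.float16)  # (bs, hidden_dim)
router_weight = torch.randn(8, 1024, device="cuda", dtype=torch.float16)  # (num_experts, hidden_dim)

topk_weights, topk_ids = fused_moe_router_shim(
    moe_softcapping=30.0,  # assumed softcap value; use the model's own setting
    hidden_states=hidden_states,
    gating_output=router_weight,  # the shim takes the router weight matrix
    topk=2,
    renormalize=False,  # renormalization is not supported by the fused kernels
)
# topk_weights: (bs, topk) float32 softmax weights of the selected experts
# topk_ids:     (bs, topk) int32 expert indices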
sglang/srt/layers/parameter.py
CHANGED
@@ -16,6 +16,7 @@ __all__ = [
     "ModelWeightParameter",
     "ChannelQuantScaleParameter",
     "GroupQuantScaleParameter",
+    "BlockQuantScaleParameter",
     "PackedColumnParameter",
     "RowvLLMParameter",
 ]
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
     pass


+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
 class PerTensorScaleParameter(BasevLLMParameter):
     """
     Parameter class for scales where the number of scales is
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,4 +1,6 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+import builtins
+import inspect
 import re
 from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
@@ -6,10 +8,7 @@ from typing import Callable, Dict, Optional, Type, Union
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinConfig,
-    AWQMoEMethod,
-)
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -28,6 +27,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -50,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
 }


@@ -178,96 +179,117 @@ def gptq_get_quant_method(self, layer, prefix):
     return None


-
-from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinLinearMethod,
-    AWQMoEMethod,
-)
+original_isinstance = builtins.isinstance

-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

-
-
-
-
-
-        return AWQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
-        return AWQMoEMethod(self)
-    return None
+def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
+    """
+    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
+    can recognize sglang layers
+    """

+    if reverse:
+        builtins.isinstance = original_isinstance
+        return

-
-
-
-def awq_moe_method_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    return original_awq_moe_method_apply(
-        self,
-        layer,
-        x,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        custom_routing_function,
-        scoring_func,
-        e_score_correction_bias,
-    )
-
-
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
+    from vllm.model_executor.layers.fused_moe import FusedMoE
     from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )

     from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-
-
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )

     def patched_isinstance(obj, classinfo):
         if classinfo is LinearBase:
             return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
         return original_isinstance(obj, classinfo)

     builtins.isinstance = patched_isinstance


-def apply_monkey_patches():
+def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
+    """
+    Monkey patch the apply function of vllm's FusedMoEMethodBase.
+    Convert sglang arguments to vllm arguments.
+    """
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
+
+
+def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod

     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-
+
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


-
-# Apply patches when module is imported
-apply_monkey_patches()
+monkey_patch_quant_configs()


 __all__ = [
-    "QuantizationConfig",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
 ]
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import ModelWeightParameter
+from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
 from sglang.srt.utils import set_weight_attrs

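`BlockQuantScaleParameter` is now imported from `sglang.srt.layers.parameter` instead of `sglang.srt.layers.quantization.fp8_utils`. Downstream code that imported the old location may want a fallback; this try/except is a suggestion, not part of the diff:

try:
    # 0.4.4.post1 location (see the parameter.py hunk above)
    from sglang.srt.layers.parameter import BlockQuantScaleParameter
except ImportError:
    # 0.4.3.post4 location
    from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter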