sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
CHANGED
@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-
-
-    min_num_warps = 16 if _is_hip else 32
-
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
-    fused_moe_router_kernel[
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
-
-
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-
-        tl.where(
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
    )
 
-    # 6. store to output
-
-
-    tl.store(topk_ids_ptr +
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr +
-
-        mask=
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk
+        and topk <= 2
         and num_experts <= BLOCK_SIZE_N
         and hidden_dim % BLOCK_SIZE_K == 0
     ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
        topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
     )
 
 
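For orientation, the router change above threads an optional correction_bias through fused_moe_router_impl and adds it to the tanh-softcapped logits before top-k selection. The snippet below is a rough PyTorch reference of that math as it appears in the hunks; it is an illustrative sketch, not code from the package, and moe_router_reference is a made-up name.

import torch


def moe_router_reference(x, router_weight, topk, moe_softcapping, correction_bias=None):
    # logits: (bs, num_experts); router_weight is (num_experts, hidden_dim)
    logits = x @ router_weight.t()
    # tanh softcapping, i.e. (exp(2l/c) - 1) / (exp(2l/c) + 1) * c as in the kernel
    logits_softcapped = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        # the new branch adds the bias after softcapping
        logits_softcapped = logits_softcapped + correction_bias
    top_vals, topk_ids = torch.topk(logits_softcapped, topk, dim=-1)
    # the stored weights are softmax probabilities of the selected experts
    topk_weights = torch.exp(
        top_vals - torch.logsumexp(logits_softcapped, dim=-1, keepdim=True)
    )
    return topk_weights, topk_ids


x = torch.randn(4, 64)
router_weight = torch.randn(8, 64)
weights, ids = moe_router_reference(
    x, router_weight, topk=2, moe_softcapping=30.0, correction_bias=torch.zeros(8)
)
print(weights.shape, ids.shape)  # torch.Size([4, 2]) torch.Size([4, 2])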
sglang/srt/layers/moe/topk.py
CHANGED
@@ -18,12 +18,12 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
@@ -35,6 +35,7 @@ from sglang.srt.utils import (
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
@@ -42,6 +43,7 @@ _is_hip = is_hip()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -106,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-
-
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -159,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
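The import rewrites above reflect the EPLB reorganization listed in the file table (the expert-location helpers moved from sglang/srt/managers to the new sglang/srt/eplb package). A hedged sketch of the adjusted import paths downstream code would use, assuming sglang 0.4.9 is installed; the paths are taken directly from the hunk above:

from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.eplb.expert_location_dispatch import (
    ExpertLocationDispatchInfo,
    topk_ids_logical_to_physical,
)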
sglang/srt/layers/parameter.py
CHANGED
@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 
 logger = logging.getLogger(__name__)
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
-
-
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
-
-
-
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
            )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
-
-
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
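The non-CPU branches above keep sharding weights with Tensor.narrow. A tiny standalone illustration of that slicing, with made-up sizes rather than sglang code:

import torch

full_weight = torch.arange(12.0).reshape(6, 2)  # pretend output_dim == 0
tp_size, tp_rank = 3, 1
shard_size = full_weight.shape[0] // tp_size    # 2 rows per rank
# same call pattern as loaded_weight.narrow(self.output_dim, tp_rank * shard_size, shard_size)
shard = full_weight.narrow(0, tp_rank * shard_size, shard_size)
print(shard)  # rows 2 and 3 of the full tensor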
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -68,6 +68,7 @@ from sglang.srt.layers.quantization.modelopt_quant import (
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
@@ -82,6 +83,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
+    "w4afp8": W4AFp8Config,
 }
 
 # VLLM-dependent quantization methods
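A quick sanity-check sketch of the new registration, assuming sglang 0.4.9 is installed; the names come straight from the hunks above:

from sglang.srt.layers.quantization import BASE_QUANTIZATION_METHODS
from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config

assert BASE_QUANTIZATION_METHODS["w4afp8"] is W4AFp8Config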
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
CHANGED
@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
             layer.input_scale = torch.nn.Parameter(
                 layer.input_scale.data, requires_grad=False
             )
-        prepare_fp8_layer_for_marlin(layer,
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -1,7 +1,7 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import Any, Callable, Dict, List, Optional, Union
 
 import torch
 import torch.nn.functional as F
@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -86,7 +88,7 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_hip_int4 = get_bool_env_var("SGLANG_INT4_WEIGHT")
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
-if _is_hip:
+if _is_hip and (_use_aiter or _use_hip_int4):
     from aiter import ActivationType, QuantType
     from aiter.fused_moe import fused_moe
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages
@@ -198,7 +200,7 @@ class Fp8LinearMethod(LinearMethodBase):
         quant_config: The quantization config.
     """
 
-    def __init__(self, quant_config: Fp8Config):
+    def __init__(self, quant_config: Union["Fp8Config", "W4AFp8Config"]):
         self.quant_config = quant_config
         self.cutlass_fp8_supported = cutlass_fp8_supported()
 
@@ -284,7 +286,10 @@ class Fp8LinearMethod(LinearMethodBase):
         if self.quant_config.is_checkpoint_fp8_serialized:
             # WEIGHT SCALE
             if self.block_quant:
-
+                if hasattr(self.quant_config, "activation_scheme"):
+                    assert self.quant_config.activation_scheme == "dynamic"
+                elif hasattr(self.quant_config, "linear_activation_scheme"):
+                    assert self.quant_config.linear_activation_scheme == "dynamic"
                 scale = BlockQuantScaleParameter(
                     data=torch.empty(
                         (output_size_per_partition + block_n - 1) // block_n,
@@ -306,7 +311,13 @@ class Fp8LinearMethod(LinearMethodBase):
                 layer.register_parameter("weight_scale", scale)
 
             # INPUT ACTIVATION SCALE
-            if
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 scale = PerTensorScaleParameter(
                     data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
                     weight_loader=weight_loader,
@@ -330,6 +341,12 @@ class Fp8LinearMethod(LinearMethodBase):
                 )
 
                 layer.input_scale = None
+            elif _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["weight"])
+                return
             else:
                 weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
                 layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -363,7 +380,13 @@ class Fp8LinearMethod(LinearMethodBase):
             layer.weight_scale = torch.nn.Parameter(
                 layer.weight_scale.data, requires_grad=False
             )
-            if
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = torch.nn.Parameter(
                     layer.input_scale.data, requires_grad=False
                 )
@@ -397,7 +420,13 @@ class Fp8LinearMethod(LinearMethodBase):
             # Update layer with new values.
             layer.weight = Parameter(weight.t(), requires_grad=False)
             layer.weight_scale = Parameter(weight_scale, requires_grad=False)
-            if
+            if (
+                hasattr(self.quant_config, "activation_scheme")
+                and self.quant_config.activation_scheme == "static"
+            ) or (
+                hasattr(self.quant_config, "linear_activation_scheme")
+                and self.quant_config.linear_activation_scheme == "static"
+            ):
                 layer.input_scale = Parameter(
                     layer.input_scale.max(), requires_grad=False
                 )
@@ -426,6 +455,17 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +786,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +1018,24 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
        )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
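Because Fp8LinearMethod now accepts either Fp8Config or W4AFp8Config, the same static-activation check is repeated at three call sites above, once per attribute name (activation_scheme vs., presumably, linear_activation_scheme on the W4AFp8 config). A possible refactor, sketched here with a hypothetical helper name that is not in the diff:

def _has_static_activation(quant_config) -> bool:
    # equivalent to the hasattr/== "static" pairs used in the hunks above
    return (
        getattr(quant_config, "activation_scheme", None) == "static"
        or getattr(quant_config, "linear_activation_scheme", None) == "static"
    )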
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
sglang/srt/layers/quantization/gptq.py
CHANGED
@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)