sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/router.py
CHANGED

@@ -1,4 +1,4 @@
-from typing import Tuple
+from typing import Optional, Tuple
 
 import torch
 import triton
@@ -16,6 +16,8 @@ def fused_moe_router_kernel(
     moe_router_weight_ptr,  # input (num_experts, hidden_dim)
     topk_weights_ptr,  # output (bs, topk)
     topk_ids_ptr,  # output (bs, topk)
+    correction_bias_ptr,
+    is_correction_bias: tl.constexpr,
     num_experts: tl.constexpr,
     topk: tl.constexpr,
     moe_softcapping: tl.constexpr,
@@ -49,6 +51,11 @@ def fused_moe_router_kernel(
     bottom = exped + 1
     logits_softcapped = top / bottom * moe_softcapping
 
+    # Add bias after softcapping
+    if is_correction_bias:
+        bias = tl.load(correction_bias_ptr + tl.arange(0, num_experts))
+        logits_softcapped = logits_softcapped + bias
+
     # topk
     # assert 1 <= topk <= num_experts
 
@@ -109,6 +116,7 @@ def fused_moe_router_impl(
     router_weight: torch.Tensor,
     topk: int,
     moe_softcapping: float,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert len(x.shape) == 2 and x.shape[1] == router_weight.shape[1]
     bs, hidden_dim = x.shape
@@ -117,23 +125,23 @@ def fused_moe_router_impl(
     # router_logits = torch.empty((bs, num_experts), dtype=torch.float32, device=x.device)
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
+    is_correction_bias = correction_bias is not None
 
-    min_num_warps = 16 if _is_hip else 32
+    max_warps = 16 if _is_hip else 32
     config = {
         "BLOCK_SIZE": triton.next_power_of_2(hidden_dim),
         "num_warps": max(
-            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)),
+            min(triton.next_power_of_2(triton.cdiv(hidden_dim, 256)), max_warps), 4
         ),
     }
 
-    fused_moe_router_kernel[
+    fused_moe_router_kernel[(bs,)](
         x,
         router_weight,
         topk_weights,
         topk_ids,
+        correction_bias,
+        is_correction_bias=is_correction_bias,
         num_experts=num_experts,
         topk=topk,
         moe_softcapping=moe_softcapping,
@@ -153,7 +161,7 @@ def fused_moe_router_large_bs_kernel(
     topk_ids_ptr,  # output (bs, topk)
     bs,
     num_experts: tl.constexpr,
-    topk: tl.constexpr,  # only support topk
+    topk: tl.constexpr,  # only support topk <= 2
     moe_softcapping: tl.constexpr,
     moe_renormalize: tl.constexpr,  # not supported
     K: tl.constexpr,
@@ -204,25 +212,53 @@ def fused_moe_router_large_bs_kernel(
     logits_softcapped = (exped - 1) / (exped + 1) * moe_softcapping
 
     # 5. top1
+    arange_block_size_n = tl.arange(0, BLOCK_SIZE_N)[None, :]
+    cond_top1 = arange_block_size_n < num_experts
+    top1 = tl.argmax(tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1)
     top1_v = tl.max(
-        tl.where(
+        tl.where(cond_top1, logits_softcapped, float("-inf")), axis=1, keep_dims=True
     )
-    tl.where(
+    top1_invsumexp = 1.0 / tl.sum(
+        tl.where(cond_top1, tl.exp(logits_softcapped - top1_v), 0.0), axis=1
     )
 
-    # 6. store to output
-    tl.store(topk_ids_ptr +
+    # 6. store top1 to output
+    offs_top1 = pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)
+    top1_mask = offs_top1 < bs * topk
+    tl.store(topk_ids_ptr + offs_top1, top1, mask=top1_mask)
     tl.store(
-        topk_weights_ptr +
-        mask=
+        topk_weights_ptr + offs_top1,
+        top1_invsumexp,
+        mask=top1_mask,
     )
 
+    # 7. handle topk == 2
+    if topk == 2:
+        cond_top2 = (arange_block_size_n < num_experts) and (
+            arange_block_size_n != top1[:, None]
+        )
+        top2 = tl.argmax(
+            tl.where(cond_top2, logits_softcapped, float("-inf")),
+            axis=1,
+            keep_dims=True,
+        )
+        top2_v = tl.sum(
+            logits_softcapped * (arange_block_size_n == top2), axis=1, keep_dims=True
+        )
+        top2_invsumexp = tl.exp(top2_v - top1_v) * top1_invsumexp[:, None]
+
+        # store top2
+        offs_top2 = (
+            pid * topk * BLOCK_SIZE_M + topk * tl.arange(0, BLOCK_SIZE_M)[:, None] + 1
+        )
+        top2_mask = offs_top2 < bs * topk
+        tl.store(topk_ids_ptr + offs_top2, top2, mask=top2_mask)
+        tl.store(
+            topk_weights_ptr + offs_top2,
+            top2_invsumexp,
+            mask=top2_mask,
+        )
+
 
 def fused_moe_router_large_bs_impl(
     x: torch.Tensor,
@@ -239,7 +275,7 @@ def fused_moe_router_large_bs_impl(
 
     assert num_experts <= BLOCK_SIZE_N
     assert hidden_dim % BLOCK_SIZE_K == 0
-    assert topk
+    assert topk <= 2
 
     topk_weights = torch.empty((bs, topk), dtype=torch.float32, device=x.device)
     topk_ids = torch.empty((bs, topk), dtype=torch.int32, device=x.device)
@@ -273,6 +309,7 @@ def fused_moe_router_shim(
     gating_output,
     topk,
     renormalize,
+    correction_bias: Optional[torch.Tensor] = None,
 ):
     assert not renormalize
     assert (
@@ -286,7 +323,7 @@ def fused_moe_router_shim(
     BLOCK_SIZE_K = 256
     if (
         bs >= 512
-        and topk
+        and topk <= 2
        and num_experts <= BLOCK_SIZE_N
        and hidden_dim % BLOCK_SIZE_K == 0
    ):
@@ -305,6 +342,7 @@ def fused_moe_router_shim(
         router_weight=gating_output,
         topk=topk,
         moe_softcapping=moe_softcapping,
+        correction_bias=correction_bias,
     )
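For orientation: the small-batch router kernel softcaps the routing logits with a tanh-style cap, optionally adds the new correction bias after softcapping, and then selects the top-k experts with softmax weights. A rough PyTorch reference of that math is sketched below; it is a readability aid only, the function name is made up, and the Triton kernels above remain the source of truth.

import torch

def fused_moe_router_reference(x, router_weight, topk, moe_softcapping, correction_bias=None):
    # (bs, hidden_dim) @ (num_experts, hidden_dim).T -> (bs, num_experts)
    logits = x.float() @ router_weight.float().t()
    # softcapping: (exp(2z/c) - 1) / (exp(2z/c) + 1) * c == tanh(z / c) * c
    logits = torch.tanh(logits / moe_softcapping) * moe_softcapping
    if correction_bias is not None:
        # the kernel adds the bias after softcapping
        logits = logits + correction_bias
    probs = torch.softmax(logits, dim=-1)
    topk_weights, topk_ids = torch.topk(probs, topk, dim=-1)
    return topk_weights, topk_ids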
sglang/srt/layers/moe/topk.py
CHANGED

@@ -18,34 +18,43 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F
 
-from sglang.srt.
-from sglang.srt.
+from sglang.srt.eplb import expert_location_dispatch
+from sglang.srt.eplb.expert_distribution import (
     ExpertDistributionRecorder,
     get_global_expert_distribution_recorder,
 )
-from sglang.srt.
+from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
+    get_bool_env_var,
     get_compiler_backend,
     is_cpu,
     is_cuda,
     is_hip,
+    is_npu,
 )
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_npu = is_npu()
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
 if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax
+if _use_aiter:
+    try:
+        from aiter import biased_grouped_topk as aiter_biased_grouped_topk
+    except ImportError:
+        raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
 
 
 def fused_topk_torch_native(
@@ -99,37 +108,14 @@ def fused_topk(
         M, topk, dtype=torch.float32, device=hidden_states.device
     )
     topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
-    token_expert_indicies = torch.empty(
-        M, topk, dtype=torch.int32, device=hidden_states.device
-    )
 
     topk_softmax(
         topk_weights,
         topk_ids,
-    )
-    del token_expert_indicies
-
-    return _fused_topk_postprocess(
-        topk_weights=topk_weights,
-        topk_ids=topk_ids,
-        renormalize=renormalize,
-        expert_location_dispatch_info=expert_location_dispatch_info,
-        num_token_non_padded=num_token_non_padded,
+        gating_output,
+        renormalize,
     )
 
-
-@torch.compile(dynamic=True, backend=get_compiler_backend())
-def _fused_topk_postprocess(
-    topk_weights,
-    topk_ids,
-    renormalize,
-    expert_location_dispatch_info,
-    num_token_non_padded,
-):
-    if renormalize:
-        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
     topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
     return topk_weights, topk_ids
@@ -152,6 +138,9 @@ def grouped_topk_gpu(
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     scores = torch.softmax(gating_output, dim=-1)
+    # NPU compiler limitation
+    if _is_npu and scores.dtype == torch.bfloat16:
+        scores = scores.to(torch.float16)
     num_token = scores.shape[0]
     num_experts = scores.shape[1]
     group_scores = (
@@ -347,6 +336,25 @@ def biased_grouped_topk_gpu(
             topk_ids, expert_location_dispatch_info, num_token_non_padded
         )
         return topk_weights, topk_ids
+    elif _use_aiter:
+        token = gating_output.shape[0]
+        device = gating_output.device
+        assert (
+            hidden_states.shape[0] == gating_output.shape[0]
+        ), f"Number of tokens mismatch: hidden_states.shape[0] = {hidden_states.shape[0]}, gating_output.shape[0] = {gating_output.shape[0]}"
+        topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
+        topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
+        aiter_biased_grouped_topk(
+            gating_output,
+            correction_bias,
+            topk_weights,
+            topk_ids,
+            num_expert_group,
+            topk_group,
+            renormalize,
+            routed_scaling_factor,
+        )
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
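A usage note on the new aiter path: it is only considered on ROCm (HIP) builds and only when SGLANG_USE_AITER is set; otherwise the existing torch.compile implementation of biased grouped top-k is used. A minimal sketch of that gate (the helper name is hypothetical; it only mirrors the module-level check added above and runs without aiter installed):

import os
import torch

def _would_use_aiter_biased_grouped_topk() -> bool:
    # ROCm build + SGLANG_USE_AITER enabled, mirroring `_use_aiter` above.
    is_hip = torch.version.hip is not None
    return is_hip and os.getenv("SGLANG_USE_AITER", "false").lower() in ("1", "true")

# When the gate is taken, the caller pre-allocates the outputs that
# aiter_biased_grouped_topk fills in place:
#   topk_weights: (num_tokens, topk) float32
#   topk_ids:     (num_tokens, topk) int32
print(_would_use_aiter_biased_grouped_topk())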
sglang/srt/layers/parameter.py
CHANGED

@@ -7,6 +7,8 @@ from typing import Callable, Optional, Union
 import torch
 from torch.nn import Parameter
 
+from sglang.srt.utils import is_cpu
+
 __all__ = [
     "BasevLLMParameter",
     "PackedvLLMParameter",
@@ -21,6 +23,8 @@ __all__ = [
 
 logger = logging.getLogger(__name__)
 
+_is_cpu = is_cpu()
+
 
 class BasevLLMParameter(Parameter):
     """
@@ -93,9 +97,28 @@ class _ColumnvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.output_dim]
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
+
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.output_dim,
+                    shard_size,
+                )
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert self.data.shape == loaded_weight.shape
         self.data.copy_(loaded_weight)
 
@@ -116,10 +139,27 @@ class _ColumnvLLMParameter(BasevLLMParameter):
         param_data = self.data
 
         param_data = param_data.narrow(self.output_dim, shard_offset, shard_size)
+
+        from sglang.srt.model_loader.weight_utils import (
+            narrow_padded_param_and_loaded_weight,
+        )
+
+        if _is_cpu:
+            param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                param_data,
+                loaded_weight,
+                0,  # param_data_start
+                tp_rank * shard_size,
+                self.output_dim,
+                shard_size,
+                not use_presharded_weights,
             )
+        else:
+            if not use_presharded_weights:
+                loaded_weight = loaded_weight.narrow(
+                    self.output_dim, tp_rank * shard_size, shard_size
+                )
+
         assert param_data.shape == loaded_weight.shape
         param_data.copy_(loaded_weight)
 
@@ -182,10 +222,30 @@ class RowvLLMParameter(BasevLLMParameter):
     ):
         if not use_presharded_weights:
             shard_size = self.data.shape[self.input_dim]
+
+            from sglang.srt.model_loader.weight_utils import (
+                narrow_padded_param_and_loaded_weight,
             )
 
+            if _is_cpu:
+                param_data, loaded_weight = narrow_padded_param_and_loaded_weight(
+                    self.data,
+                    loaded_weight,
+                    0,  # param_data_start
+                    tp_rank * shard_size,
+                    self.input_dim,
+                    shard_size,
+                )
+
+                assert param_data.shape == loaded_weight.shape
+                param_data.copy_(loaded_weight)
+
+                return
+            else:
+                loaded_weight = loaded_weight.narrow(
+                    self.input_dim, tp_rank * shard_size, shard_size
+                )
+
         if len(loaded_weight.shape) == 0:
             loaded_weight = loaded_weight.reshape(1)
 
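The recurring pattern in these weight-loading hooks is tensor-parallel shard selection: each rank narrows the checkpoint tensor to its own slice, and the new _is_cpu branch routes through narrow_padded_param_and_loaded_weight so CPU-side weight padding is handled as well. A small self-contained illustration of the plain (non-padded) narrow step, with made-up shapes:

import torch

# Hypothetical column-parallel weight of shape (output_size, input_size).
full_weight = torch.randn(1024, 512)
tp_rank, tp_size, output_dim = 1, 4, 0

shard_size = full_weight.shape[output_dim] // tp_size
local_shard = full_weight.narrow(output_dim, tp_rank * shard_size, shard_size)
assert local_shard.shape == (256, 512)  # rank 1 owns rows 256..511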
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
CHANGED

@@ -76,7 +76,7 @@ class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
         layer.input_scale = torch.nn.Parameter(
             layer.input_scale.data, requires_grad=False
         )
-        prepare_fp8_layer_for_marlin(layer,
+        prepare_fp8_layer_for_marlin(layer, size_k_first=True)
 
     def create_weights(
         self,
sglang/srt/layers/quantization/fp8.py
CHANGED

@@ -27,6 +27,7 @@ except ImportError:
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.layers.amx_utils import _amx_process_weight_after_loading
 from sglang.srt.layers.linear import (
     LinearBase,
     LinearMethodBase,
@@ -73,6 +74,7 @@ from sglang.srt.utils import (
     log_info_on_rank0,
     print_warning_once,
     set_weight_attrs,
+    use_intel_amx_backend,
 )
 
 _is_hip = is_hip()
@@ -330,6 +332,12 @@ class Fp8LinearMethod(LinearMethodBase):
             )
 
             layer.input_scale = None
+        elif _is_cpu:
+            assert (
+                _is_cpu_amx_available
+            ), "Fp8LinearMethod on CPU requires that CPU has AMX support"
+            _amx_process_weight_after_loading(layer, ["weight"])
+            return
         else:
             weight, weight_scale = layer.weight.data, layer.weight_scale_inv.data
             layer.weight = torch.nn.Parameter(weight, requires_grad=False)
@@ -426,6 +434,17 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
         if self.block_quant:
+            if use_intel_amx_backend(layer):
+                return torch.ops.sgl_kernel.fp8_scaled_mm_cpu(
+                    x,
+                    layer.weight,
+                    layer.weight_scale_inv,
+                    self.quant_config.weight_block_size,
+                    bias,
+                    x.dtype,
+                    True,  # is_vnni
+                )
+
             return self.w8a8_block_fp8_linear(
                 input=x,
                 weight=layer.weight,
@@ -746,6 +765,13 @@ class Fp8MoEMethod:
                 layer.w2_weight.data = shuffle_weight(
                     layer.w2_weight.contiguous(), (16, 16)
                 )
+
+            if _is_cpu:
+                assert (
+                    _is_cpu_amx_available
+                ), "Fp8MoEMethod on CPU requires that CPU has AMX support"
+                _amx_process_weight_after_loading(layer, ["w13_weight", "w2_weight"])
+
             return
 
         # If checkpoint is fp16 or bfloat16, quantize in place.
@@ -971,6 +997,24 @@ class Fp8MoEMethod:
             routed_scaling_factor=routed_scaling_factor,
         )
 
+        if use_intel_amx_backend(layer):
+            return torch.ops.sgl_kernel.fused_experts_cpu(
+                x,
+                layer.w13_weight,
+                layer.w2_weight,
+                topk_weights,
+                topk_ids,
+                False,  # inplace See [Note] inplace should be False in fused_experts.
+                False,  # use_int8_w8a8
+                True,  # use_fp8_w8a16
+                layer.w13_weight_scale_inv,  # w1_scale
+                layer.w2_weight_scale_inv,  # w2_scale
+                self.quant_config.weight_block_size,  # block_size
+                None,  # a1_scale
+                None,  # a2_scale
+                True,  # is_vnni
+            )
+
         if _is_hip:
             ret = self.maybe_apply_hip_fused_experts(
                 layer,
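For orientation, both the new CPU AMX path (fp8_scaled_mm_cpu / fused_experts_cpu) and the existing w8a8_block_fp8_linear operate on block-quantized FP8 weights: the weight is stored in FP8 with one scale per weight block (weight_scale_inv) and dequantized block-wise inside the fused kernel. A rough, unfused PyTorch reference of that linear is below; the shapes, the multiply-by-scale convention, and exact divisibility by the block size are assumptions for illustration.

import torch

def block_fp8_linear_ref(x, weight_fp8, weight_scale_inv, block=(128, 128), bias=None):
    # weight_fp8: (N, K) in an FP8 dtype; weight_scale_inv: (N // bn, K // bk)
    n, k = weight_fp8.shape
    bn, bk = block
    w = weight_fp8.float().view(n // bn, bn, k // bk, bk)
    s = weight_scale_inv.float().view(n // bn, 1, k // bk, 1)
    w = (w * s).reshape(n, k)  # per-block dequantization
    out = x.float() @ w.t()
    return out if bias is None else out + bias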
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED

@@ -23,9 +23,9 @@ import torch
 import triton
 import triton.language as tl
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.utils import (
+    align,
     direct_register_custom_op,
     get_device_core_count,
     get_device_name,
sglang/srt/layers/quantization/fp8_utils.py
CHANGED

@@ -1,9 +1,7 @@
 from typing import Callable, List, Optional, Tuple
 
-import einops
 import torch
 
-from sglang.math_utils import align
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.fp8_kernel import sglang_per_token_group_quant_fp8
 from sglang.srt.layers.utils import is_sm100_supported
@@ -27,6 +25,7 @@ from sglang.srt.layers.quantization.fp8_kernel import (
     w8a8_block_fp8_matmul_triton,
 )
 from sglang.srt.utils import (
+    align,
     get_bool_env_var,
     get_cuda_version,
     get_device_capability,
@@ -42,7 +41,10 @@ _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _use_aiter:
+    import aiter
+    from aiter import gemm_a8w8_blockscale_CK, get_hip_quant
+
+    aiter_per1x128_quant = get_hip_quant(aiter.QuantType.per_1x128)
 
 if _is_cuda:
     from sgl_kernel import fp8_blockwise_scaled_mm, fp8_scaled_mm
@@ -271,9 +273,7 @@ def aiter_w8a8_block_fp8_linear(
     input_2d = input.view(-1, input.shape[-1])
     output_shape = [*input.shape[:-1], weight.shape[0]]
 
-    q_input, x_scale =
-        input_2d, block_size[1], column_major_scales=False
-    )
+    q_input, x_scale = aiter_per1x128_quant(input_2d, quant_dtype=aiter.dtypes.fp8)
     output = gemm_a8w8_blockscale_CK(
         q_input, weight, x_scale, weight_scale, dtype=input.dtype
     )
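The aiter path above quantizes activations per token in groups of 128 columns (the per_1x128 scheme) before the block-scaled GEMM. A rough reference of that quantization step; the FP8 E4M3 range and the returned-scale convention are assumptions for illustration, not the aiter kernel itself.

import torch

def per_1x128_quant_ref(x: torch.Tensor, group: int = 128, fp8_max: float = 448.0):
    bs, hidden = x.shape
    assert hidden % group == 0
    xg = x.float().view(bs, hidden // group, group)
    amax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-12)
    scale = amax / fp8_max  # one scale per (token, 128-column group)
    q = (xg / scale).clamp(-fp8_max, fp8_max).reshape(bs, hidden)
    return q.to(torch.float8_e4m3fn), scale.squeeze(-1)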
sglang/srt/layers/quantization/gptq.py
CHANGED

@@ -344,6 +344,10 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
+        assert (
+            VLLM_AVAILABLE
+        ), "vllm is not installed, to use gptq_marlin, please install vllm"
+
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
@@ -726,6 +730,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
+            quant_type_id=self.quant_config.quant_type.id,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)
sglang/srt/layers/quantization/moe_wna16.py
CHANGED

@@ -131,7 +131,7 @@ class MoeWNA16Config(QuantizationConfig):
         capability_tuple = get_device_capability()
         device_capability = (
             -1
-            if
+            if all(capability is None for capability in capability_tuple)
             else capability_tuple[0] * 10 + capability_tuple[1]
         )
         # Avoid circular import
|