sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
CHANGED
@@ -11,20 +11,23 @@ from typing import Any, Callable, Dict, List, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-from vllm import _custom_ops as
+from vllm import _custom_ops as vllm_ops

 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import
+from sglang.srt.layers.quantization.int8_kernel import (
+    per_token_group_quant_int8,
+    per_token_quant_int8,
+)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_device_name,
-
+    is_cuda,
     is_hip,
 )

-
+_is_hip = is_hip()


 logger = logging.getLogger(__name__)
@@ -34,17 +37,17 @@ enable_moe_align_block_size_triton = bool(
     int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
 )

-_is_cuda =
-_is_rocm = torch.cuda.is_available() and torch.version.hip
+_is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul

+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
     )

-if _is_cuda or
+if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


@@ -117,6 +120,7 @@ def fused_moe_kernel(
     - expert_ids: A tensor containing the indices of the expert for each
     block. It determines which expert matrix from B should be used for
     each block in A.
+
     This kernel performs the multiplication of a token by its corresponding
     expert matrix as determined by `expert_ids`. The sorting of
     `sorted_token_ids` by expert index and padding ensures divisibility by
@@ -167,17 +171,38 @@ def fused_moe_kernel(
         )
         b_scale = tl.load(b_scale_ptrs)

-    if use_fp8_w8a8
+    if use_fp8_w8a8:
+        # block-wise
         if group_k > 0 and group_n > 0:
             a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
             offs_bsn = offs_bn // group_n
             b_scale_ptrs = (
                 b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
             )
+        # tensor-wise
         else:
             a_scale = tl.load(a_scale_ptr)
             b_scale = tl.load(b_scale_ptr + off_experts)

+    if use_int8_w8a8:
+        # block-wise
+        if group_k > 0 and group_n > 0:
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            offs_bsn = offs_bn // group_n
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bsn * stride_bsn
+            )
+        # channel-wise
+        else:
+            # Load per-column scale for weights
+            b_scale_ptrs = (
+                b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
+            )
+            b_scale = tl.load(b_scale_ptrs)
+            # Load per-token scale for activations
+            a_scale_ptrs = a_scale_ptr + (offs_token // top_k) * stride_asm
+            a_scale = tl.load(a_scale_ptrs, mask=token_mask, other=0.0)[:, None]
+
     # -----------------------------------------------------------
     # Iterate to compute a block of the C matrix.
     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
@@ -217,7 +242,11 @@ def fused_moe_kernel(

             accumulator += tl.dot(a, b) * a_scale[:, None] * b_scale[None, :]
         else:
-
+            # fix out of shared memory issue
+            if use_fp8_w8a8:
+                accumulator = tl.dot(a, b, acc=accumulator)
+            else:
+                accumulator += tl.dot(a, b)
     else:
         accumulator += tl.dot(a, b)
     # Advance the ptrs to the next K block.
@@ -458,7 +487,7 @@ def moe_align_block_size(
             cumsum_buffer,
         )
     else:
-
+        vllm_ops.moe_align_block_size(
             topk_ids,
             num_experts,
             block_size,
@@ -497,9 +526,14 @@ def invoke_fused_moe_kernel(
     if use_fp8_w8a8:
         assert B_scale is not None
         if block_shape is None:
+            # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
-
+            if _is_cuda:
+                A, A_scale = sgl_scaled_fp8_quant(A, A_scale)
+            else:
+                A, A_scale = vllm_ops.scaled_fp8_quant(A, A_scale)
         else:
+            # activation block-wise fp8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
             if _is_cuda:
@@ -512,9 +546,10 @@ def invoke_fused_moe_kernel(
     elif use_int8_w8a8:
         assert B_scale is not None
         if block_shape is None:
-
-            A, A_scale =
+            # activation channel-wise int8 quantization
+            A, A_scale = per_token_quant_int8(A)
         else:
+            # activation block-wise int8 quantization
             assert len(block_shape) == 2
             block_n, block_k = block_shape[0], block_shape[1]
             A, A_scale = per_token_group_quant_int8(A, block_k)
@@ -648,7 +683,7 @@ def get_default_config(
             "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 32,
             "num_warps": 8,
-            "num_stages": 2 if
+            "num_stages": 2 if _is_hip else 4,
         }
         if M <= E:
             config = {
@@ -657,7 +692,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": 128,
                 "GROUP_SIZE_M": 1,
                 "num_warps": 4,
-                "num_stages": 2 if
+                "num_stages": 2 if _is_hip else 4,
             }
     else:
         # Block-wise quant: BLOCK_SIZE_K must be divisable by block_shape[1]
@@ -667,7 +702,7 @@ def get_default_config(
                 "BLOCK_SIZE_K": block_shape[1],
                 "GROUP_SIZE_M": 32,
                 "num_warps": 4,
-                "num_stages": 2 if
+                "num_stages": 2 if _is_hip else 3,
             }
         else:
             config = {
@@ -945,7 +980,7 @@ def fused_experts_impl(
     if (
         not (use_fp8_w8a8 or use_int8_w8a8)
         or block_shape is not None
-        or (
+        or (_is_hip and get_bool_env_var("CK_MOE"))
     ):
         padded_size = 0

@@ -1029,7 +1064,9 @@ def fused_experts_impl(
             # so the cache size and config are already set correctly and
             # do not need to be adjusted.
             intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
-            intermediate_cache2 = intermediate_cache2[
+            intermediate_cache2 = intermediate_cache2[
+                : tokens_in_chunk * topk_ids.shape[1]
+            ]
             intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
             config = get_config_func(tokens_in_chunk)

@@ -1060,17 +1097,20 @@ def fused_experts_impl(
             use_int8_w8a16=use_int8_w8a16,
             block_shape=block_shape,
         )
-
         if activation == "silu":
             if _is_cuda:
                 silu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-
+                vllm_ops.silu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         elif activation == "gelu":
             if _is_cuda:
                 gelu_and_mul(intermediate_cache1.view(-1, N), intermediate_cache2)
             else:
-
+                vllm_ops.gelu_and_mul(
+                    intermediate_cache2, intermediate_cache1.view(-1, N)
+                )
         else:
             raise ValueError(f"Unsupported activation: {activation=}")

@@ -1101,8 +1141,8 @@ def fused_experts_impl(

     if no_combine:
         pass
-    elif
-
+    elif _is_hip:
+        vllm_ops.moe_sum(
             intermediate_cache3.view(*intermediate_cache3.shape),
             out_hidden_states[begin_chunk_idx:end_chunk_idx],
         )
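The new int8 W8A8 path in invoke_fused_moe_kernel quantizes activations per token when block_shape is None (the channel-wise case) and per group of block_k elements otherwise (the block-wise case). The reference sketch below illustrates what those two helpers compute in plain PyTorch; it is an illustration of symmetric int8 quantization, not the Triton kernels that per_token_quant_int8 and per_token_group_quant_int8 actually dispatch to.

import torch

def per_token_quant_int8_ref(x: torch.Tensor):
    # One scale per row (token): the row's max magnitude maps to 127.
    absmax = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / 127.0
    q = torch.clamp(torch.round(x / scale), -128, 127).to(torch.int8)
    return q, scale  # q: [tokens, hidden] int8, scale: [tokens, 1] fp32

def per_token_group_quant_int8_ref(x: torch.Tensor, group_size: int):
    # One scale per contiguous group of `group_size` elements along the last
    # dim (the block-wise path; group_size corresponds to block_k).
    tokens, hidden = x.shape
    xg = x.view(tokens, hidden // group_size, group_size)
    absmax = xg.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8)
    scale = absmax / 127.0
    q = torch.clamp(torch.round(xg / scale), -128, 127).to(torch.int8)
    return q.view(tokens, hidden), scale.squeeze(-1)  # scale: [tokens, hidden // group_size]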
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -27,9 +27,9 @@ else:

 import logging

-
+_is_hip = is_hip()

-if
+if _is_hip:
     from aiter import ck_moe

 logger = logging.getLogger(__name__)
@@ -102,7 +102,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)

     def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
-        if
+        if _is_hip and get_bool_env_var("CK_MOE"):
             layer.w13_weight = torch.nn.Parameter(
                 permute_weight(layer.w13_weight.data),
                 requires_grad=False,
@@ -175,7 +175,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             correction_bias=correction_bias,
         )

-        if
+        if _is_hip and get_bool_env_var("CK_MOE"):
             assert not no_combine, "unsupported"
             return ck_moe(
                 x,
@@ -513,6 +513,10 @@ class FusedMoE(torch.nn.Module):

         # Case input scale: input_scale loading is only supported for fp8
         if "input_scale" in weight_name:
+            # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust input_scale for e4m3fnuz (AMD)
+            if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                loaded_weight = loaded_weight * 2.0
+
             # this is needed for compressed-tensors only
             loaded_weight = loaded_weight.to(param.data.device)

@@ -551,6 +555,10 @@ class FusedMoE(torch.nn.Module):
             # specific to each case
             quant_method = getattr(param, "quant_method", None)
             if quant_method == FusedMoeWeightScaleSupported.CHANNEL.value:
+                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust INT4 column-wise scaling number to e4m3fnuz (AMD)
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                    loaded_weight = loaded_weight * 0.5
+
                 self._load_per_channel_weight_scale(
                     shard_id=shard_id,
                     shard_dim=shard_dim,
@@ -570,6 +578,10 @@ class FusedMoE(torch.nn.Module):
                     tp_rank=tp_rank,
                 )
             elif quant_method == FusedMoeWeightScaleSupported.TENSOR.value:
+                # INT4-FP8 (INT4 MoE Weight, FP8 Compute): Adjust FP8 per-tensor scaling number for e4m3fnuz (AMD)
+                if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
+                    loaded_weight = loaded_weight * 2.0
+
                 self._load_per_tensor_weight_scale(
                     shard_id=shard_id,
                     param=param,
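Both files above follow the same hardware dispatch: sgl_kernel ops when running on CUDA, aiter's ck_moe on ROCm when the CK_MOE environment variable is enabled, and vllm's custom ops as the fallback; the INT4-FP8 scale adjustments are likewise gated on the USE_INT4_WEIGHT environment variable on HIP. A minimal sketch of that selection logic, assuming the "1"/"true" convention for truthy environment values:

import os
import torch

def _env_flag(name: str) -> bool:
    # Assumed convention: "1" or "true" enables the flag.
    return os.getenv(name, "0").lower() in ("1", "true")

def select_moe_backend() -> str:
    is_cuda = torch.cuda.is_available() and torch.version.cuda is not None
    is_hip = torch.cuda.is_available() and torch.version.hip is not None
    if is_hip and _env_flag("CK_MOE"):
        return "aiter_ck_moe"   # ROCm composable-kernel MoE path
    if is_cuda:
        return "sgl_kernel"     # native sgl_kernel ops (silu_and_mul, moe_align_block_size, ...)
    return "vllm_ops"           # fallback to vllm's _custom_ops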
sglang/srt/layers/parameter.py
CHANGED
@@ -16,6 +16,7 @@ __all__ = [
     "ModelWeightParameter",
     "ChannelQuantScaleParameter",
     "GroupQuantScaleParameter",
+    "BlockQuantScaleParameter",
     "PackedColumnParameter",
     "RowvLLMParameter",
 ]
@@ -221,6 +222,15 @@ class ChannelQuantScaleParameter(_ColumnvLLMParameter):
     pass


+class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+    """
+    Parameter class for weight scales loaded for weights with
+    block-wise quantization. Uses both column and row parallelism.
+    """
+
+    pass
+
+
 class PerTensorScaleParameter(BasevLLMParameter):
     """
     Parameter class for scales where the number of scales is
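BlockQuantScaleParameter carries one scale per (block_n, block_k) tile of a block-quantized weight, so the scale tensor is sharded along both the output and input dimensions, which is why the class mixes the column- and row-parallel behaviors. A small sketch of the shapes involved for the [128, 128] block shape used by the configs above; the helper below is illustrative, not part of sglang:

import torch

def block_scale_shape(n: int, k: int, block_n: int = 128, block_k: int = 128):
    # One scale per (block_n, block_k) tile of an [n, k] weight matrix.
    return ((n + block_n - 1) // block_n, (k + block_k - 1) // block_k)

# Example: a 1536 x 7168 weight quantized to fp8 with 128 x 128 blocks
weight = torch.empty(1536, 7168, dtype=torch.float8_e4m3fn)   # requires torch >= 2.1
scale = torch.empty(*block_scale_shape(1536, 7168))           # shape (12, 56), fp32 scales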
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,4 +1,6 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+import builtins
+import inspect
 import re
 from copy import deepcopy
 from typing import Callable, Dict, Optional, Type, Union
@@ -6,10 +8,7 @@ from typing import Callable, Dict, Optional, Type, Union
 import torch
 from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
 from vllm.model_executor.layers.quantization.awq import AWQConfig
-from vllm.model_executor.layers.quantization.awq_marlin import
-    AWQMarlinConfig,
-    AWQMoEMethod,
-)
+from vllm.model_executor.layers.quantization.awq_marlin import AWQMarlinConfig
 from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
 from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
@@ -28,6 +27,7 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import ModelOptFp8Config
+from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

 QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
@@ -50,6 +50,7 @@ QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
     "w8a8_int8": W8A8Int8Config,
+    "w8a8_fp8": W8A8Fp8Config,
 }


@@ -178,96 +179,117 @@ def gptq_get_quant_method(self, layer, prefix):
     return None


-
-from vllm.model_executor.layers.quantization.awq import is_layer_skipped_awq
-from vllm.model_executor.layers.quantization.awq_marlin import (
-    AWQMarlinLinearMethod,
-    AWQMoEMethod,
-)
+original_isinstance = builtins.isinstance

-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead

-
-
-
-
-
-        return AWQMarlinLinearMethod(self)
-    elif isinstance(layer, FusedMoE):
-        return AWQMoEMethod(self)
-    return None
+def monkey_patch_isinstance_for_vllm_base_layer(reverse: bool = False):
+    """
+    Patch isinstance so that the `get_quant_method` in vllm's QuantizationConfig
+    can recognize sglang layers
+    """

+    if reverse:
+        builtins.isinstance = original_isinstance
+        return

-
-
-
-def awq_moe_method_apply(
-    self,
-    layer: torch.nn.Module,
-    x: torch.Tensor,
-    router_logits: torch.Tensor,
-    top_k: int,
-    renormalize: bool,
-    use_grouped_topk: bool = False,
-    topk_group: Optional[int] = None,
-    num_expert_group: Optional[int] = None,
-    custom_routing_function: Optional[Callable] = None,
-    scoring_func: str = "softmax",
-    e_score_correction_bias: Optional[torch.Tensor] = None,
-    **kwargs,
-):
-    return original_awq_moe_method_apply(
-        self,
-        layer,
-        x,
-        router_logits,
-        top_k,
-        renormalize,
-        use_grouped_topk,
-        topk_group,
-        num_expert_group,
-        custom_routing_function,
-        scoring_func,
-        e_score_correction_bias,
-    )
-
-
-def patch_vllm_linear_base_isinstance():
-    import builtins
-
+    from vllm.model_executor.layers.fused_moe import FusedMoE
     from vllm.model_executor.layers.linear import LinearBase
+    from vllm.model_executor.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding,
+    )

     from sglang.srt.layers.linear import LinearBase as PatchedLinearBase
-
-
+    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE as PatchedFusedMoE
+    from sglang.srt.layers.vocab_parallel_embedding import (
+        VocabParallelEmbedding as PatchedVocabParallelEmbedding,
+    )

     def patched_isinstance(obj, classinfo):
         if classinfo is LinearBase:
             return original_isinstance(obj, PatchedLinearBase)
+        if classinfo is FusedMoE:
+            return original_isinstance(obj, PatchedFusedMoE)
+        if classinfo is VocabParallelEmbedding:
+            return original_isinstance(obj, PatchedVocabParallelEmbedding)
         return original_isinstance(obj, classinfo)

     builtins.isinstance = patched_isinstance


-def
+def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
+    """
+    Monkey patch the apply function of vllm's FusedMoEMethodBase.
+    Convert sglang arguments to vllm arguments.
+    """
+    original_apply = class_obj.apply
+    sig = inspect.signature(original_apply)
+    param_names = list(sig.parameters.keys())
+    has_correction_bias = "e_score_correction_bias" in param_names
+
+    def new_apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        custom_routing_function: Optional[Callable] = None,
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+        inplace: bool = True,
+        no_combine: bool = False,
+    ):
+        assert activation == "silu"
+        assert inplace and not no_combine
+
+        kwargs = {
+            "self": self,
+            "layer": layer,
+            "x": x,
+            "router_logits": router_logits,
+            "top_k": top_k,
+            "renormalize": renormalize,
+            "use_grouped_topk": use_grouped_topk,
+            "topk_group": topk_group,
+            "num_expert_group": num_expert_group,
+            "custom_routing_function": custom_routing_function,
+        }
+        if correction_bias is not None:
+            if not has_correction_bias:
+                raise ValueError(
+                    "Please increase the version of your vllm. Try `pip install vllm==0.7.2`"
+                )
+            kwargs["e_score_correction_bias"] = correction_bias
+        return original_apply(**kwargs)
+
+    setattr(class_obj, "apply", new_apply)
+
+
+def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
     from vllm.model_executor.layers.quantization.awq_marlin import AWQMoEMethod
+    from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
+        CompressedTensorsW8A8Fp8MoEMethod,
+        CompressedTensorsWNA16MoEMethod,
+    )
+    from vllm.model_executor.layers.quantization.gptq_marlin import GPTQMarlinMoEMethod

     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
-
-
+
+    monkey_patch_moe_apply(AWQMoEMethod)
+    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
+    monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)


-
-# Apply patches when module is imported
-apply_monkey_patches()
+monkey_patch_quant_configs()


 __all__ = [
-    "QuantizationConfig",
     "get_quantization_config",
     "QUANTIZATION_METHODS",
 ]
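monkey_patch_isinstance_for_vllm_base_layer works by swapping builtins.isinstance: vllm's get_quant_method checks layers against vllm's own LinearBase, FusedMoE, and VocabParallelEmbedding classes, and the patch redirects those checks to sglang's equivalents until it is undone with reverse=True. A self-contained sketch of the same technique, using stand-in classes rather than the real vllm and sglang layers:

import builtins

class VllmLinearBase: ...      # stand-in for vllm's LinearBase
class SglangLinearBase: ...    # stand-in for sglang's LinearBase

_original_isinstance = builtins.isinstance

def patch(reverse: bool = False):
    if reverse:
        builtins.isinstance = _original_isinstance
        return

    def patched_isinstance(obj, classinfo):
        # Redirect checks against the vllm class to the sglang class.
        if classinfo is VllmLinearBase:
            return _original_isinstance(obj, SglangLinearBase)
        return _original_isinstance(obj, classinfo)

    builtins.isinstance = patched_isinstance

layer = SglangLinearBase()
patch()
assert isinstance(layer, VllmLinearBase)        # vllm-style check now accepts the sglang layer
patch(reverse=True)
assert not isinstance(layer, VllmLinearBase)    # original behavior restored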
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -13,12 +13,11 @@ from sglang.srt.layers.linear import (
     LinearMethodBase,
     UnquantizedLinearMethod,
 )
-from sglang.srt.layers.parameter import
+from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_utils import BlockQuantScaleParameter
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
 from sglang.srt.utils import set_weight_attrs

@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 4
+    },
+    "16": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 5
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "512": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 3
+    }
+}
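The new 146-line JSON files listed above are tuned launch-parameter tables for fused_moe_kernel: each key is a token count M and each value is a Triton config (block sizes, group size, warps, stages) benchmarked for that batch size on the named GPU. Below is a sketch of how such a table can be consulted by picking the nearest key; the lookup helper is illustrative rather than sglang's actual config loader.

# Keys are token counts M; values are fused_moe_kernel launch parameters
# (a few entries copied from the table above).
table = {
    "64":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 5},
    "256":  {"BLOCK_SIZE_M": 32,  "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64, "BLOCK_SIZE_K": 128,
             "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 3},
}

def pick_moe_config(table: dict, m: int) -> dict:
    # Nearest-key selection over the benchmarked batch sizes (a common lookup
    # rule for these tables, not the exact code path in fused_moe.py).
    return table[min(table, key=lambda k: abs(int(k) - m))]

print(pick_moe_config(table, m=300))  # -> the "256" entry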