sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ sglang/srt/layers/quantization/awq_triton.py
@@ -0,0 +1,339 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py
+
+# SPDX-License-Identifier: Apache-2.0
+# SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+import torch
+import triton
+import triton.language as tl
+
+AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+@triton.jit
+def awq_dequantize_kernel(
+    qweight_ptr,  # quantized matrix
+    scales_ptr,  # scales, per group
+    zeros_ptr,  # zeros, per group
+    group_size,  # Should always be one of the supported group sizes
+    result_ptr,  # Output matrix
+    num_cols,  # input num cols in qweight
+    num_rows,  # input num rows in qweight
+    BLOCK_SIZE_X: tl.constexpr,
+    BLOCK_SIZE_Y: tl.constexpr,
+):
+    # Setup the pids.
+    pid_x = tl.program_id(axis=0)
+    pid_y = tl.program_id(axis=1)
+
+    # Compute offsets and masks for qweight_ptr.
+    offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+    offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
+    offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
+
+    masks_y = offsets_y < num_rows
+    masks_x = offsets_x < num_cols
+
+    masks = masks_y[:, None] & masks_x[None, :]
+
+    # Compute offsets and masks for result output ptr.
+    result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+    result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
+    result_offsets = (
+        8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
+    )
+
+    result_masks_y = result_offsets_y < num_rows
+    result_masks_x = result_offsets_x < num_cols * 8
+    result_masks = result_masks_y[:, None] & result_masks_x[None, :]
+
+    # Load the weights.
+    iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
+    iweights = tl.interleave(iweights, iweights)
+    iweights = tl.interleave(iweights, iweights)
+    iweights = tl.interleave(iweights, iweights)
+
+    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+    # that will map given indices to the correct order.
+    reverse_awq_order_tensor = (
+        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
+    ).reshape(8)
+
+    # Use this to compute a set of shifts that can be used to unpack and
+    # reorder the values in iweights and zeros.
+    shifts = reverse_awq_order_tensor * 4
+    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
+    shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+    # Unpack and reorder: shift out the correct 4-bit value and mask.
+    iweights = (iweights >> shifts) & 0xF
+
+    # Compute zero offsets and masks.
+    zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
+    zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
+    zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
+
+    zero_masks_y = zero_offsets_y < num_rows // group_size
+    zero_masks_x = zero_offsets_x < num_cols
+    zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
+
+    # Load the zeros.
+    zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
+    zeros = tl.interleave(zeros, zeros)
+    zeros = tl.interleave(zeros, zeros)
+    zeros = tl.interleave(zeros, zeros)
+    zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+    # Unpack and reorder: shift out the correct 4-bit value and mask.
+    zeros = (zeros >> shifts) & 0xF
+
+    # Compute scale offsets and masks.
+    scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
+    scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
+    scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]
+    scale_masks_y = scale_offsets_y < num_rows // group_size
+    scale_masks_x = scale_offsets_x < num_cols * 8
+    scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
+
+    # Load the scales.
+    scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
+    scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+    # Dequantize.
+    iweights = (iweights - zeros) * scales
+    iweights = iweights.to(result_ptr.type.element_ty)
+
+    # Finally, store.
+    tl.store(result_ptr + result_offsets, iweights, result_masks)
+
+
+@triton.jit
+def awq_gemm_kernel(
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    zeros_ptr,
+    scales_ptr,
+    M,
+    N,
+    K,
+    group_size,
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    SPLIT_K: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    pid_z = tl.program_id(1)
+
+    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+    # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+    pid_m = pid // num_pid_n
+    pid_n = pid % num_pid_n
+
+    accumulator_dtype = c_ptr.type.element_ty
+
+    # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+    # accumulator = tl.arange(0, BLOCK_SIZE_N)
+    # accumulator = tl.broadcast_to(accumulator[None, :],
+    #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
+    # accumulator = accumulator & 0x0
+    # accumulator = accumulator.to(accumulator_dtype)
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+    # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+    # that will map given indices to the correct order.
+    reverse_awq_order_tensor = (
+        (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
+    ).reshape(8)
+
+    # Create the necessary shifts to use to unpack.
+    shifts = reverse_awq_order_tensor * 4
+    shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
+    shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+    # Offsets and masks.
+    offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    masks_am = offsets_am < M
+
+    offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
+    masks_bn = offsets_bn < N // 8
+
+    offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
+    masks_zn = offsets_zn < N // 8
+
+    offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    masks_sn = offsets_sn < N
+
+    offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+    offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
+    offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
+
+    a_ptrs = a_ptr + offsets_a
+    b_ptrs = b_ptr + offsets_b
+
+    # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
+    # block_offset = BLOCK_SIZE_K * SPLIT_K
+    # for k in range(0, (K + block_offset - 1) // (block_offset)):
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
+        masks_k = offsets_k < K
+        masks_a = masks_am[:, None] & masks_k[None, :]
+        a = tl.load(a_ptrs, mask=masks_a, other=0.0)
+
+        masks_b = masks_k[:, None] & masks_bn[None, :]
+        b = tl.load(b_ptrs, mask=masks_b, other=0.0)
+        b = tl.interleave(b, b)
+        b = tl.interleave(b, b)
+        b = tl.interleave(b, b)
+
+        # Dequantize b.
+        offsets_szk = (
+            BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
+        ) // group_size + tl.arange(0, 1)
+        offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
+        masks_zk = offsets_szk < K // group_size
+        masks_z = masks_zk[:, None] & masks_zn[None, :]
+        zeros_ptrs = zeros_ptr + offsets_z
+        zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
+        zeros = tl.interleave(zeros, zeros)
+        zeros = tl.interleave(zeros, zeros)
+        zeros = tl.interleave(zeros, zeros)
+        zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+        offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
+        masks_sk = offsets_szk < K // group_size
+        masks_s = masks_sk[:, None] & masks_sn[None, :]
+        scales_ptrs = scales_ptr + offsets_s
+        scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
+        scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+        b = (b >> shifts) & 0xF
+        zeros = (zeros >> shifts) & 0xF
+        b = (b - zeros) * scales
+        b = b.to(c_ptr.type.element_ty)
+
+        # Accumulate results.
+        accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+        offsets_k += BLOCK_SIZE_K * SPLIT_K
+        a_ptrs += BLOCK_SIZE_K * SPLIT_K
+        b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
+
+    c = accumulator.to(c_ptr.type.element_ty)
+    offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
+    c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, c, mask=c_mask)
+
+
+# qweights - [K     , M // 8], int32
+# scales   - [K // G, M     ], float16
+# zeros    - [K // G, M // 8], int32
+def awq_dequantize_triton(
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    zeros: torch.Tensor,
+    block_size_x: int = 32,
+    block_size_y: int = 32,
+) -> torch.Tensor:
+    K = qweight.shape[0]
+    M = scales.shape[1]
+    group_size = qweight.shape[0] // scales.shape[0]
+
+    assert K > 0 and M > 0
+    assert scales.shape[0] == K // group_size and scales.shape[1] == M
+    assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
+    assert group_size <= K
+    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+    # Result tensor:
+    # number of rows = same as input tensor
+    # number of cols = 8 x input tensor num cols
+    result = torch.empty(
+        qweight.shape[0],
+        qweight.shape[1] * 8,
+        device=qweight.device,
+        dtype=scales.dtype,
+    )
+
+    Y = qweight.shape[0]  # num rows
+    X = qweight.shape[1]  # num cols
+
+    grid = lambda META: (
+        triton.cdiv(X, META["BLOCK_SIZE_X"]),
+        triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
+    )
+    awq_dequantize_kernel[grid](
+        qweight,
+        scales,
+        zeros,
+        group_size,
+        result,
+        X,
+        Y,
+        BLOCK_SIZE_X=block_size_x,
+        BLOCK_SIZE_Y=block_size_y,
+    )
+
+    return result
+
+
+# input   - [M, K]
+# qweight - [K, N // 8]
+# qzeros  - [K // G, N // 8]
+# scales  - [K // G, N]
+# split_k_iters - parallelism along K-dimension, int, power of 2.
+def awq_gemm_triton(
+    input: torch.Tensor,
+    qweight: torch.Tensor,
+    scales: torch.Tensor,
+    qzeros: torch.Tensor,
+    split_k_iters: int,
+    block_size_m: int = 32,
+    block_size_n: int = 32,
+    block_size_k: int = 32,
+) -> torch.Tensor:
+    M, K = input.shape
+    N = qweight.shape[1] * 8
+    group_size = qweight.shape[0] // qzeros.shape[0]
+
+    assert N > 0 and K > 0 and M > 0
+    assert qweight.shape[0] == K and qweight.shape[1] == N // 8
+    assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
+    assert scales.shape[0] == K // group_size and scales.shape[1] == N
+    assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
+    assert split_k_iters <= 32
+    assert group_size <= K
+    assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+    grid = lambda META: (
+        triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+        split_k_iters,
+    )
+
+    result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)
+
+    # A = input, B = qweight, C = result
+    # A = M x K, B = K x N, C = M x N
+    awq_gemm_kernel[grid](
+        input,
+        qweight,
+        result,
+        qzeros,
+        scales,
+        M,
+        N,
+        K,
+        group_size,
+        BLOCK_SIZE_M=block_size_m,
+        BLOCK_SIZE_N=block_size_n,
+        BLOCK_SIZE_K=block_size_k,
+        SPLIT_K=split_k_iters,
+    )
+
+    result = result.sum(0)
+
+    return result
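For orientation, here is a minimal usage sketch of the two public helpers this new module exposes. The shapes, dtypes, and random tensors below are illustrative placeholders chosen only to satisfy the assertions above, not data from a real AWQ checkpoint; a CUDA device and Triton are assumed to be available.

import torch

from sglang.srt.layers.quantization.awq_triton import (
    awq_dequantize_triton,
    awq_gemm_triton,
)

# Placeholder AWQ-packed weight: K x (M // 8) int32 words, one scale/zero row per group.
K, M, group_size = 1024, 512, 128
qweight = torch.randint(0, 2**31 - 1, (K, M // 8), dtype=torch.int32, device="cuda")
scales = torch.rand(K // group_size, M, dtype=torch.float16, device="cuda")
qzeros = torch.randint(0, 2**31 - 1, (K // group_size, M // 8), dtype=torch.int32, device="cuda")

# Unpack to a dense [K, M] fp16 weight.
weight = awq_dequantize_triton(qweight, scales, qzeros)

# Or run the fused GEMM directly: [batch, K] activations times the packed weight -> [batch, M].
x = torch.rand(8, K, dtype=torch.float16, device="cuda")
out = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)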
--- sglang/srt/layers/quantization/base_config.py
+++ sglang/srt/layers/quantization/base_config.py
@@ -1,12 +1,16 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+from __future__ import annotations
 
 import inspect
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Optional, Type
+from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type
 
 import torch
 from torch import nn
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 
 class QuantizeMethodBase(ABC):
     """Base class for different quantized methods."""
@@ -18,14 +22,14 @@ class QuantizeMethodBase(ABC):
         """Create weights for a layer.
 
         The weights will be set as attributes of the layer."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @abstractmethod
     def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
         """Apply the weights in layer to the input tensor.
 
         Expects create_weights to have been called before on the layer."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     def process_weights_after_loading(self, layer: nn.Module) -> None:
         """Process the weight after loading.
@@ -35,6 +39,77 @@ class QuantizeMethodBase(ABC):
         return
 
 
+class LinearMethodBase(QuantizeMethodBase):
+    """Base class for different (maybe quantized) linear methods."""
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: List[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """Create weights for a linear layer.
+        The weights will be set as attributes of the layer.
+
+        Args:
+            layer: The layer that is using the LinearMethodBase factory.
+            input_size_per_partition: Size of the weight input dim on rank X.
+            output_partition_sizes: Sizes of the output dim of each logical
+                weight on rank X. E.g., output_partition_sizes for QKVLinear
+                is a list contains the width of Wq, Wk, Wv on rank X.
+            input_size: Size of the input dim of the weight across all ranks.
+            output_size: Size of the output dim of the weight across all ranks.
+            params_dtype: Datatype of the parameters.
+        """
+        raise NotImplementedError()
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Apply the weights in layer to the input tensor.
+        Expects create_weights to have been called before on the layer."""
+        raise NotImplementedError()
+
+
+class FusedMoEMethodBase(QuantizeMethodBase):
+
+    @abstractmethod
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        raise NotImplementedError
+
+    @abstractmethod
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        *,
+        activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
+        inplace: bool = True,
+        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        raise NotImplementedError
+
+
 class QuantizationConfig(ABC):
     """Base class for quantization configs."""
 
@@ -46,12 +121,12 @@ class QuantizationConfig(ABC):
     @abstractmethod
     def get_name(self) -> str:
         """Name of the quantization method."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @abstractmethod
     def get_supported_act_dtypes(self) -> List[torch.dtype]:
         """List of supported activation dtypes."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @classmethod
     @abstractmethod
@@ -62,19 +137,19 @@ class QuantizationConfig(ABC):
         This requirement is due to the custom CUDA kernels used by the
         quantization method.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @staticmethod
     @abstractmethod
     def get_config_filenames() -> List[str]:
         """List of filenames to search for in the model directory."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @classmethod
     @abstractmethod
     def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
         """Create a config class from the model's quantization config."""
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @classmethod
     def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -117,7 +192,7 @@ class QuantizationConfig(ABC):
         The quantize method. None if the given layer doesn't support quant
         method.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
     @abstractmethod
     def get_scaled_act_names(self) -> List[str]:
@@ -125,7 +200,7 @@ class QuantizationConfig(ABC):
 
         For now, this is only used by AWQ.
         """
-        raise NotImplementedError
+        raise NotImplementedError()
 
 
 def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
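The practical effect of these base_config.py hunks is that LinearMethodBase and FusedMoEMethodBase now sit next to QuantizeMethodBase, and the MoE apply() contract receives a precomputed TopKOutput. As a rough sketch of the relocated linear interface (the class name DummyLinearMethod is made up for illustration; the real unquantized implementation ships in the new sglang/srt/layers/quantization/unquant.py):

from typing import List, Optional

import torch

from sglang.srt.layers.quantization.base_config import LinearMethodBase


class DummyLinearMethod(LinearMethodBase):
    """Illustrative only: a plain (unquantized) linear method built on the relocated base class."""

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        # One dense weight covering every output partition owned by this rank.
        weight = torch.nn.Parameter(
            torch.empty(
                sum(output_partition_sizes),
                input_size_per_partition,
                dtype=params_dtype,
            ),
            requires_grad=False,
        )
        layer.register_parameter("weight", weight)

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        return torch.nn.functional.linear(x, layer.weight, bias)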
--- sglang/srt/layers/quantization/blockwise_int8.py
+++ sglang/srt/layers/quantization/blockwise_int8.py
@@ -1,26 +1,29 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py
 
+from __future__ import annotations
+
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional
 
 import torch
 from torch.nn import Module
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 from sglang.srt.layers.quantization.utils import is_layer_skipped
 from sglang.srt.utils import set_weight_attrs
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
@@ -78,7 +81,7 @@ class BlockInt8Config(QuantizationConfig):
         return []
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
+    def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config:
         quant_method = cls.get_from_keys(config, ["quant_method"])
         is_checkpoint_int8_serialized = "int8" in quant_method
         activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@@ -93,7 +96,8 @@ class BlockInt8Config(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
         if isinstance(layer, LinearBase):
@@ -230,7 +234,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
         )
 
 
-class BlockInt8MoEMethod:
+class BlockInt8MoEMethod(FusedMoEMethodBase):
     """MoE method for INT8.
     Supports loading INT8 checkpoints with static weight scale and
    dynamic activation scale.
@@ -242,25 +246,7 @@ class BlockInt8MoEMethod:
         quant_config: The quantization config.
     """
 
-    def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-        if not hasattr(cls, "_initialized"):
-            original_init = cls.__init__
-            new_cls = type(
-                cls.__name__,
-                (FusedMoEMethodBase,),
-                {
-                    "__init__": original_init,
-                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                },
-            )
-            obj = super(new_cls, new_cls).__new__(new_cls)
-            obj.__init__(*args, **kwargs)
-            return obj
-        return super().__new__(cls)
-
-    def __init__(self, quant_config):
+    def __init__(self, quant_config: BlockInt8Config):
         self.quant_config = quant_config
         assert self.quant_config.weight_block_size is not None
         assert self.quant_config.is_checkpoint_int8_serialized
@@ -361,15 +347,8 @@ class BlockInt8MoEMethod:
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -377,30 +356,13 @@ class BlockInt8MoEMethod:
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        # Expert selection
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         # Expert fusion with INT8 quantization
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             apply_router_weight_on_input=apply_router_weight_on_input,
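The last two hunks make the routing split explicit: BlockInt8MoEMethod.apply no longer calls select_experts itself but receives the routing result. A hedged sketch of the new call shape follows; the moe_method, layer, and topk_output objects are assumed to exist already, and how a TopKOutput is built lives in sglang/srt/layers/moe/topk.py, which is not shown in this diff.

import torch


def run_blockwise_int8_moe(moe_method, layer: torch.nn.Module, x: torch.Tensor, topk_output) -> torch.Tensor:
    # Expert selection has already happened upstream; the quantized MoE method
    # only dispatches to the experts and combines their outputs.
    return moe_method.apply(
        layer,
        x,
        topk_output,
        activation="silu",
        inplace=True,
    )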