sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +133 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +32 -21
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +133 -30
- sglang/srt/managers/scheduler.py +273 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +27 -13
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +208 -77
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +124 -28
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +99 -9
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +167 -123
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -0,0 +1,658 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+import enum
+import logging
+from enum import Enum
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors import CompressionFormat
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.moe.fused_moe_triton import (
+    FusedMoE,
+    FusedMoEMethodBase,
+    FusedMoeWeightScaleSupported,
+)
+from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    is_cuda,
+    is_fp8_fnuz,
+    per_tensor_dequantize,
+    replace_parameter,
+)
+from sglang.srt.utils import set_weight_attrs
+
+_is_cuda = is_cuda()
+
+if _is_cuda:
+    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
+else:
+    from vllm import _custom_ops as vllm_ops
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
+
+logger = logging.getLogger(__name__)
+
+
+class GPTQMarlinState(Enum):
+    REPACK = enum.auto()
+    READY = enum.auto()
+
+
+__all__ = [
+    "CompressedTensorsMoEMethod",
+    "CompressedTensorsW8A8Fp8MoEMethod",
+    "CompressedTensorsWNA16MoEMethod",
+]
+
+
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+    ) -> "CompressedTensorsMoEMethod":
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        weight_quant = quant_config.target_scheme_map["Linear"].get("weights")
+        input_quant = quant_config.target_scheme_map["Linear"].get("input_activations")
+
+        if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
+            if not VLLM_AVAILABLE:
+                raise ImportError(
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                )
+            return CompressedTensorsWNA16MoEMethod(quant_config)
+        elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
+            return CompressedTensorsW8A8Fp8MoEMethod(quant_config)
+        else:
+            raise RuntimeError(
+                f"Unsupported FusedMoe scheme: {weight_quant}, {input_quant}"
+            )
+
+
+class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
+            "input_activations"
+        )
+
+        if not (
+            self.weight_quant.strategy == QuantizationStrategy.TENSOR
+            and self.input_quant.strategy == QuantizationStrategy.TENSOR
+        ):
+            raise ValueError(
+                "For FP8 Fused MoE layers, only per-tensor scales "
+                "for weights and activations are supported. Found "
+                f"{self.weight_quant}, {self.input_quant}"
+            )
+
+        self.static_input_scales = not self.input_quant.dynamic
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        params_dtype = torch.float8_e4m3fn
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        # Allocate 2 scales for w1 and w3 respectively.
+        # They will be combined to a single scale after weight loading.
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, 2, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.TENSOR.value}
+        )
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+
+        # INPUT_SCALES
+        if self.static_input_scales:
+            w13_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w13_input_scale", w13_input_scale)
+            set_weight_attrs(w13_input_scale, extra_weight_attrs)
+
+            w2_input_scale = torch.nn.Parameter(
+                torch.ones(num_experts, dtype=torch.float32), requires_grad=False
+            )
+            layer.register_parameter("w2_input_scale", w2_input_scale)
+            set_weight_attrs(w2_input_scale, extra_weight_attrs)
+        else:
+            layer.w13_input_scale = None
+            layer.w2_input_scale = None
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # Fp8 moe kernels require a single activation scale.
+        # We take the max of all the scales in case they differ.
+        if self.static_input_scales:
+            if layer.w13_input_scale is None or layer.w2_input_scale is None:
+                raise ValueError(
+                    "QuantConfig has static quantization, but found "
+                    "activation scales are None."
+                )
+            if not all_close_1d(layer.w13_input_scale) or not all_close_1d(
+                layer.w2_input_scale
+            ):
+                logger.warning(
+                    "Found input_scales that are not equal for "
+                    "fp8 MoE layer. Using the maximum across experts "
+                    "for each layer."
+                )
+            layer.w13_input_scale = torch.nn.Parameter(
+                layer.w13_input_scale.max(), requires_grad=False
+            )
+            layer.w2_input_scale = torch.nn.Parameter(
+                layer.w2_input_scale.max(), requires_grad=False
+            )
+
+        if is_fp8_fnuz():
+            # Normalize the weights and scales
+            w13_weight, w13_weight_scale, w13_input_scale = (
+                normalize_e4m3fn_to_e4m3fnuz(
+                    layer.w13_weight, layer.w13_weight_scale, layer.w13_input_scale
+                )
+            )
+            w2_weight, w2_weight_scale, w2_input_scale = normalize_e4m3fn_to_e4m3fnuz(
+                layer.w2_weight, layer.w2_weight_scale, layer.w2_input_scale
+            )
+            # Reset the parameter
+            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
+            layer.w13_weight_scale = torch.nn.Parameter(
+                w13_weight_scale, requires_grad=False
+            )
+            if w13_input_scale is not None:
+                layer.w13_input_scale = torch.nn.Parameter(
+                    w13_input_scale, requires_grad=False
+                )
+            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
+            layer.w2_weight_scale = torch.nn.Parameter(
+                w2_weight_scale, requires_grad=False
+            )
+            if w2_input_scale is not None:
+                layer.w2_input_scale = torch.nn.Parameter(
+                    w2_input_scale, requires_grad=False
+                )
+
+        # Fp8 moe kernel needs single weight scale for w13 per expert.
+        # We take the max then dequant and requant each expert.
+        assert layer.w13_weight_scale is not None
+        shard_size = layer.intermediate_size_per_partition
+        max_w13_scales = layer.w13_weight_scale.max(dim=1).values
+        for expert_id in range(layer.local_num_experts):
+            start = 0
+            for shard_id in range(2):
+                dq_weight = per_tensor_dequantize(
+                    layer.w13_weight[expert_id][start : start + shard_size, :],
+                    layer.w13_weight_scale[expert_id][shard_id],
+                )
+
+                if _is_cuda:
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                else:
+                    layer.w13_weight[expert_id][start : start + shard_size, :], _ = (
+                        vllm_ops.scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
+                    )
+                start += shard_size
+
+        layer.w13_weight_scale = torch.nn.Parameter(max_w13_scales, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            correction_bias=correction_bias,
+        )
+
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=True,
+            activation=activation,
+            use_fp8_w8a8=True,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+        )
+
+
+class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
+
+    def __init__(
+        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
+    ):
+        self.quant_config = quant_config
+        # TODO: @dsikka: refactor this to use schemes as other kernels
+        # are supported + check if the layer is being ignored.
+        config = self.quant_config.target_scheme_map["Linear"].get("weights")
+        self.num_bits = config.num_bits
+        self.packed_factor = 32 // config.num_bits
+        self.strategy = config.strategy
+        self.group_size = config.group_size
+        self.actorder = config.actorder
+        assert config.symmetric, "Only symmetric quantization is supported for MoE"
+
+        if not (
+            self.quant_config.quant_format == CompressionFormat.pack_quantized.value
+            and self.num_bits in WNA16_SUPPORTED_BITS
+        ):
+            raise ValueError(
+                "For Fused MoE layers, only ",
+                f"{CompressionFormat.pack_quantized.value} ",
+                "is supported for the following bits: ",
+                f"{WNA16_SUPPORTED_BITS}",
+            )
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        assert (
+            params_dtype == torch.float16
+        ), "float16 is required for MoE compressed models. Set dtype=torch.float16"  # noqa: E501
+
+        intermediate_size_full = extra_weight_attrs.pop("intermediate_size_full")
+
+        # Will transpose the loaded weight along the
+        # intermediate and hidden dim sizes. Will
+        # shard for TP along the transposed dims
+        extra_weight_attrs.update(
+            {"is_transposed": True, "quant_method": self.strategy}
+        )
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size // self.packed_factor,
+                2 * intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_packed", w13_weight)
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition // self.packed_factor,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_packed", w2_weight)
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # In the case where we have actorder/g_idx,
+        # we do not partition the w2 scales
+        load_full_w2 = self.actorder and self.group_size != -1
+        w2_scales_size = (
+            intermediate_size_full if load_full_w2 else intermediate_size_per_partition
+        )
+
+        self.is_k_full = (not self.actorder) or (
+            intermediate_size_per_partition == intermediate_size_full
+        )
+
+        if self.strategy == "channel":
+            num_groups_w2 = num_groups_w13 = 1
+            self.group_size = -1
+        else:
+            num_groups_w2 = w2_scales_size // self.group_size
+            num_groups_w13 = hidden_size // self.group_size
+
+        w13_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                num_groups_w13,
+                2 * intermediate_size_per_partition,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_scale", w13_scale)
+        set_weight_attrs(w13_scale, extra_weight_attrs)
+
+        w2_scale = torch.nn.Parameter(
+            torch.ones(num_experts, num_groups_w2, hidden_size, dtype=params_dtype),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_scale", w2_scale)
+        set_weight_attrs(w2_scale, extra_weight_attrs)
+        set_weight_attrs(w2_scale, {"load_full_w2": load_full_w2})
+
+        w2_weight_shape = torch.nn.Parameter(
+            torch.empty(num_experts, 2), requires_grad=False
+        )
+        layer.register_parameter("w2_weight_shape", w2_weight_shape)
+        set_weight_attrs(w2_weight_shape, extra_weight_attrs)
+        w13_weight_shape = torch.nn.Parameter(
+            torch.empty(num_experts, 2), requires_grad=False
+        )
+
+        layer.register_parameter("w13_weight_shape", w13_weight_shape)
+        set_weight_attrs(w13_weight_shape, extra_weight_attrs)
+
+        w13_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight_g_idx", w13_g_idx)
+        set_weight_attrs(w13_g_idx, extra_weight_attrs)
+
+        w2_g_idx = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight_g_idx", w2_g_idx)
+        set_weight_attrs(w2_g_idx, extra_weight_attrs)
+
+        w13_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+        set_weight_attrs(w13_g_idx_sort_indices, extra_weight_attrs)
+
+        w2_g_idx_sort_indices = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                intermediate_size_per_partition,
+                dtype=torch.int32,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+        set_weight_attrs(w2_g_idx_sort_indices, extra_weight_attrs)
+
+        layer.a13_scale = None
+        layer.a2_scale = None
+        layer.marlin_state = GPTQMarlinState.REPACK
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+
+        def replace_tensor(name, new_t):
+            # It is important to use resize_() here since it ensures
+            # the same buffer is reused
+            getattr(layer, name).resize_(new_t.shape)
+            getattr(layer, name).copy_(new_t)
+            del new_t
+
+        def get_scale_perms(num_bits: int):
+            scale_perm: List[int] = []
+            for i in range(8):
+                scale_perm.extend([i + 8 * j for j in range(8)])
+            scale_perm_single: List[int] = []
+            for i in range(4):
+                scale_perm_single.extend(
+                    [2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]]
+                )
+            return scale_perm, scale_perm_single
+
+        def marlin_permute_scales(
+            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+        ):
+            scale_perm, scale_perm_single = get_scale_perms(num_bits)
+            if group_size < size_k and group_size != -1:
+                s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+            else:
+                s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+            s = s.reshape((-1, size_n)).contiguous()
+            return s
+
+        def marlin_moe_permute_scales(
+            s: torch.Tensor, size_k: int, size_n: int, group_size: int, num_bits: int
+        ):
+            num_experts = s.shape[0]
+            output = torch.empty(
+                (num_experts, s.shape[1], s.shape[2]), device=s.device, dtype=s.dtype
+            )
+            for e in range(num_experts):
+                output[e] = marlin_permute_scales(
+                    s[e], size_k, size_n, group_size, num_bits
+                )
+            return output
+
+        size_k2 = layer.w2_weight_packed.shape[2]
+        size_k13 = layer.w13_weight_packed.shape[2]
+
+        num_experts = layer.w13_weight_g_idx.shape[0]
+        device = layer.w13_weight_g_idx.device
+
+        # when running models with grouped act order,
+        # resort to g_idx values provided in checkpoint
+        if self.actorder == "group":
+            w13_g_idx_sort_indices = torch.empty_like(layer.w13_weight_g_idx)
+            w2_g_idx_sort_indices = torch.empty_like(layer.w2_weight_g_idx)
+            w13_sorted_g_idx = torch.empty_like(layer.w13_weight_g_idx)
+            w2_sorted_g_idx = torch.empty_like(layer.w2_weight_g_idx)
+
+            for e in range(num_experts):
+                w13_g_idx_sort_indices[e] = torch.argsort(layer.w13_weight_g_idx[e]).to(
+                    torch.int32
+                )
+                w2_g_idx_sort_indices[e] = torch.argsort(layer.w2_weight_g_idx[e]).to(
+                    torch.int32
+                )
+                w13_sorted_g_idx[e] = layer.w13_weight_g_idx[e][
+                    w13_g_idx_sort_indices[e]
+                ]
+                w2_sorted_g_idx[e] = layer.w2_weight_g_idx[e][w2_g_idx_sort_indices[e]]
+
+            replace_parameter(layer, "w13_weight_g_idx", w13_sorted_g_idx)
+            replace_parameter(layer, "w2_weight_g_idx", w2_sorted_g_idx)
+            replace_parameter(layer, "w13_g_idx_sort_indices", w13_g_idx_sort_indices)
+            replace_parameter(layer, "w2_g_idx_sort_indices", w2_g_idx_sort_indices)
+
+        else:
+            layer.w13_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_weight_g_idx = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w13_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+            layer.w2_g_idx_sort_indices = torch.nn.Parameter(
+                torch.empty((num_experts, 0), dtype=torch.int32, device=device),
+                requires_grad=False,
+            )
+
+        marlin_w13_qweight = ops.gptq_marlin_moe_repack(
+            layer.w13_weight_packed,
+            layer.w13_g_idx_sort_indices,
+            layer.w13_weight_packed.shape[1] * self.packed_factor,
+            layer.w13_weight_packed.shape[2],
+            self.num_bits,
+        )
+        replace_tensor("w13_weight_packed", marlin_w13_qweight)
+        marlin_w2_qweight = ops.gptq_marlin_moe_repack(
+            layer.w2_weight_packed,
+            layer.w2_g_idx_sort_indices,
+            layer.w2_weight_packed.shape[1] * self.packed_factor,
+            layer.w2_weight_packed.shape[2],
+            self.num_bits,
+        )
+        replace_tensor("w2_weight_packed", marlin_w2_qweight)
+        # Repack scales
+        marlin_w13_scales = marlin_moe_permute_scales(
+            layer.w13_weight_scale,
+            size_k13,
+            layer.w13_weight_scale.shape[2],
+            self.group_size,
+            self.num_bits,
+        )
+        replace_tensor("w13_weight_scale", marlin_w13_scales)
+        marlin_w2_scales = marlin_moe_permute_scales(
+            layer.w2_weight_scale,
+            layer.w2_weight_scale.shape[1]
+            * (self.group_size if self.group_size != -1 else self.packed_factor),
+            size_k2,
+            self.group_size,
+            self.num_bits,
+        )
+        replace_tensor("w2_weight_scale", marlin_w2_scales)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        top_k: int,
+        renormalize: bool,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        global_num_experts: int = -1,
+        expert_map: Optional[torch.Tensor] = None,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        activation: str = "silu",
+    ) -> torch.Tensor:
+        assert activation == "silu", "Only SiLU activation is supported."
+        if not VLLM_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed, to use fused_marlin_moe, please install vllm"
+            )
+        if expert_map is not None:
+            raise NotImplementedError(
+                "Expert Parallelism is not supported for " "fused Marlin MoE method."
+            )
+
+        topk_weights, topk_ids = select_experts(
+            hidden_states=x,
+            router_logits=router_logits,
+            use_grouped_topk=use_grouped_topk,
+            top_k=top_k,
+            renormalize=renormalize,
+            topk_group=topk_group,
+            num_expert_group=num_expert_group,
+            custom_routing_function=custom_routing_function,
+            scoring_func=scoring_func,
+            correction_bias=correction_bias,
+        )
+
+        return torch.ops.vllm.fused_marlin_moe(
+            x,
+            layer.w13_weight_packed,
+            layer.w2_weight_packed,
+            layer.w13_weight_scale,
+            layer.w2_weight_scale,
+            router_logits,
+            topk_weights,
+            topk_ids,
+            g_idx1=layer.w13_weight_g_idx,
+            g_idx2=layer.w2_weight_g_idx,
+            sort_indices1=layer.w13_g_idx_sort_indices,
+            sort_indices2=layer.w2_g_idx_sort_indices,
+            num_bits=self.num_bits,
+            is_k_full=self.is_k_full,
+        )