sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py (new file)
@@ -0,0 +1,652 @@
# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
# SPDX-License-Identifier: Apache-2.0

import logging
from contextlib import suppress
from typing import Any, Dict, List, Literal, NamedTuple, Optional, Tuple, cast

import torch
from compressed_tensors.config import (
    CompressionFormat,
    SparsityCompressionConfig,
    SparsityStructure,
)
from compressed_tensors.quantization import (
    QuantizationArgs,
    QuantizationStrategy,
    QuantizationType,
)
from pydantic import BaseModel

from sglang.srt.layers.linear import (
    LinearBase,
    LinearMethodBase,
    UnquantizedLinearMethod,
)
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
from sglang.srt.layers.quantization.base_config import (
    QuantizationConfig,
    QuantizeMethodBase,
)
from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe import (  # noqa: E501
    CompressedTensorsMoEMethod,
)
from sglang.srt.layers.quantization.compressed_tensors.schemes import (
    CompressedTensorsScheme,
    CompressedTensorsW8A8Fp8,
)
from sglang.srt.layers.quantization.compressed_tensors.utils import (
    find_matched_target,
    is_activation_quantization_format,
    should_ignore_layer,
)
from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod

logger = logging.getLogger(__name__)

__all__ = ["CompressedTensorsLinearMethod"]

SPARSITY_CONFIG_NAME: Literal["sparsity_config"] = "sparsity_config"
QUANTIZATION_SCHEME_MAP_TYPE = Dict[str, Optional[Dict[str, QuantizationArgs]]]


class DeviceCapability(NamedTuple):
    major: int
    minor: int

    def as_version_str(self) -> str:
        return f"{self.major}.{self.minor}"

    def to_int(self) -> int:
        """
        Express device capability as an integer ``<major><minor>``.

        It is assumed that the minor version is always a single digit.
        """
        assert 0 <= self.minor < 10
        return self.major * 10 + self.minor


class CompressedTensorsConfig(QuantizationConfig):

    def __init__(
        self,
        target_scheme_map: Dict[str, Any],
        ignore: List[str],
        quant_format: str,
        sparsity_scheme_map: Dict[str, SparsityCompressionConfig],
        sparsity_ignore_list: List[str],
        kv_cache_scheme: Optional[Dict[str, Any]] = None,
        config: Optional[Dict[str, Any]] = None,
    ):
        super().__init__()
        self.ignore = ignore
        self.quant_format = quant_format
        # Map from [target -> scheme]
        self.target_scheme_map = target_scheme_map
        self.kv_cache_scheme = kv_cache_scheme
        self.sparsity_scheme_map = sparsity_scheme_map
        self.sparsity_ignore_list = sparsity_ignore_list
        self.config = config

    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
        return CompressedTensorsLinearMethod(self)

    def get_supported_act_dtypes(cls) -> List[torch.dtype]:
        return [torch.float16, torch.bfloat16]

    @classmethod
    def get_min_capability(cls) -> int:
        return 70

    def get_name(self) -> str:
        return "compressed_tensors"

    def get_scaled_act_names(self) -> List[str]:
        return []

    def get_quant_method(
        self,
        layer: torch.nn.Module,
        prefix: str,
    ) -> Optional["QuantizeMethodBase"]:

        # Check if the layer is skipped for quantization.
        # TODO (@robertgshaw2): support module names
        if should_ignore_layer(
            prefix, ignore=self.ignore, fused_mapping=self.packed_modules_mapping
        ):
            return UnquantizedLinearMethod()
        if isinstance(layer, LinearBase):
            scheme = self.get_scheme(layer=layer, layer_name=prefix)
            if scheme is None:
                return UnquantizedLinearMethod()
            layer.scheme = scheme
            return CompressedTensorsLinearMethod(self)
        if isinstance(layer, FusedMoE):
            return CompressedTensorsMoEMethod.get_moe_method(self)
        return None

    @classmethod
    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
        ignore: List[str] = cast(List[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
        sparsity_scheme_map, sparsity_ignore_list = cls._parse_sparsity_config(
            config=config
        )

        return cls(
            target_scheme_map=target_scheme_map,
            ignore=ignore,
            quant_format=quant_format,
            sparsity_scheme_map=sparsity_scheme_map,
            sparsity_ignore_list=sparsity_ignore_list,
            config=config,
        )

    @classmethod
    def _parse_sparsity_config(
        cls, config: Dict[str, Any]
    ) -> Tuple[Dict[str, SparsityCompressionConfig], List[str]]:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A tuple with two elements
            1. A dictionary mapping target layer names to their corresponding
                sparsity_config
            2. A list of layer names to ignore for sparsity
        """
        if not (sparsity_config := config.get(SPARSITY_CONFIG_NAME)):
            return dict(), []

        sparsity_config = SparsityCompressionConfig.model_validate(sparsity_config)
        sparse_scheme_map: Dict[str, SparsityCompressionConfig] = {
            target: sparsity_config for target in sparsity_config.targets or list()
        }
        sparsity_ignore_list = sparsity_config.ignore or list()
        return sparse_scheme_map, sparsity_ignore_list

    @classmethod
    def _quantization_scheme_map_from_config(
        cls, config: Dict[str, Any]
    ) -> QUANTIZATION_SCHEME_MAP_TYPE:
        """
        :param config: The `quantization_config` dictionary from config.json
        :return: A dictionary mapping target layer names to their corresponding
            quantization_args for weights and input activations
        """
        target_scheme_map: Dict[str, Any] = dict()
        quant_format = cast(str, config.get("format"))

        # The quant_config has multiple config_groups, each containing
        # an input_activations key with details about how the activations are
        # quantized, a weights key indicating how the weights are quantized,
        # and a list of targets under the `targets` key, dictating which
        # layers are impacted by the quantization details. The quantization
        # details follow the structure defined by the QuantizationArgs
        # pydantic model, which is used to verify the structure of the
        # quant_config and also store the details for later use.

        config_groups = config.get("config_groups", dict())
        for _, quant_config in config_groups.items():
            targets = quant_config.get("targets")
            for target in targets:
                target_scheme_map[target] = {}
                target_scheme_map[target]["weights"] = QuantizationArgs.model_validate(
                    quant_config.get("weights")
                )

                target_scheme_map[target]["input_activations"] = None
                if is_activation_quantization_format(quant_format):
                    input_activations = quant_config.get("input_activations")
                    # The only case where we have activation quant supported
                    # but no input_activations provided in the config
                    # should be w8a16fp8 w8a16fp8 can also run for cases where
                    # there is an input_quant but it is ignored
                    if not input_activations:
                        assert (
                            target_scheme_map[target]["weights"].type
                            == QuantizationType.FLOAT
                        )
                    else:
                        target_scheme_map[target]["input_activations"] = (
                            QuantizationArgs.model_validate(  # noqa: E501
                                quant_config.get("input_activations")
                            )
                        )
        return target_scheme_map

    @classmethod
    def get_config_filenames(cls) -> List[str]:
        return []

    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
        capability_tuple = DeviceCapability(*torch.cuda.get_device_capability())

        if capability_tuple is not None:
            capability = capability_tuple.to_int()
            supported = capability >= min_capability
            if error and not supported:
                raise RuntimeError(
                    "Quantization scheme is not supported for ",
                    f"the current GPU. Min capability: {min_capability}. ",
                    f"Current capability: {capability}.",
                )
            return supported
        else:
            return False

    def _is_static_tensor_w8a8(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
        )
        is_tensor = (
            weight_strategy
            and input_quant.strategy == QuantizationStrategy.TENSOR.value
        )
        is_static = not weight_quant.dynamic and not input_quant.dynamic

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_tensor and weight_quant.symmetric and is_static

    def _is_dynamic_token_w8a8(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        is_8_bits = weight_quant.num_bits == input_quant.num_bits == 8
        weight_strategy = (
            weight_quant.strategy == QuantizationStrategy.TENSOR.value
            or weight_quant.strategy == QuantizationStrategy.CHANNEL.value
        )
        is_token = (
            weight_strategy and input_quant.strategy == QuantizationStrategy.TOKEN.value
        )
        is_dynamic = not weight_quant.dynamic and input_quant.dynamic

        # Both symmetric and asymmetric input quantization supported.
        # Only symmetric weight quantization supported.
        return is_8_bits and is_token and weight_quant.symmetric and is_dynamic

    def _is_fp8_w8a8(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
        # Confirm weights and activations quantized.
        if weight_quant is None or input_quant is None:
            return False

        # Confirm weight scheme is supported.
        is_floating_point = (
            weight_quant.type == QuantizationType.FLOAT
            and input_quant.type == QuantizationType.FLOAT
        )
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = weight_quant.strategy in [
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.CHANNEL,
        ]
        if not (
            is_floating_point
            and is_symmetric_weight
            and is_static_weight
            and is_per_tensor_or_channel_weight
        ):
            return False

        # Dynamic quantization is always supported if weights supported.
        if input_quant.dynamic:
            return True

        # Confirm activation scheme is supported.
        is_symmetric_activation = input_quant.symmetric
        is_per_tensor_activation = input_quant.strategy == QuantizationStrategy.TENSOR
        return is_symmetric_activation and is_per_tensor_activation

    def _is_fp8_w8a16(self, weight_quant: BaseModel, input_quant: BaseModel) -> bool:
        # Confirm weights quantized.
        if weight_quant is None:
            return False

        # Confirm we have floating points.
        if weight_quant.type != QuantizationType.FLOAT:
            return False

        # Confirm weight scheme is supported.
        is_symmetric_weight = weight_quant.symmetric
        is_static_weight = not weight_quant.dynamic
        is_per_tensor_or_channel_weight = weight_quant.strategy in [
            QuantizationStrategy.TENSOR,
            QuantizationStrategy.CHANNEL,
        ]
        if not (
            is_symmetric_weight
            and is_static_weight  # noqa: SIM103
            and is_per_tensor_or_channel_weight
        ):
            return False

        # All conditions satisfied.
        return True

    def _is_wNa16_group_channel(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> bool:
        input_quant_none = input_quant is None
        is_symmetric = weight_quant.symmetric
        is_channel_group = (
            weight_quant.strategy == QuantizationStrategy.CHANNEL.value
            or weight_quant.strategy == QuantizationStrategy.GROUP.value
        )
        is_static = not weight_quant.dynamic

        return is_channel_group and input_quant_none and is_symmetric and is_static

    def _get_scheme_from_parts(
        self, weight_quant: BaseModel, input_quant: BaseModel
    ) -> "CompressedTensorsScheme":

        # Detect If Mixed Precision
        if self._is_wNa16_group_channel(weight_quant, input_quant):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensorsW4A16Sparse24 and CompressedTensorsWNA16, please install vllm"
                )
            if (
                self.quant_format == CompressionFormat.marlin_24.value
                and weight_quant.num_bits in W4A16SPARSE24_SUPPORTED_BITS
            ):
                return CompressedTensorsW4A16Sparse24(
                    strategy=weight_quant.strategy,
                    num_bits=weight_quant.num_bits,
                    group_size=weight_quant.group_size,
                )
            if (
                self.quant_format == CompressionFormat.pack_quantized.value
                and weight_quant.num_bits in WNA16_SUPPORTED_BITS
            ):
                return CompressedTensorsWNA16(
                    num_bits=weight_quant.num_bits,
                    strategy=weight_quant.strategy,
                    group_size=weight_quant.group_size,
                    actorder=weight_quant.actorder,
                )

        if is_activation_quantization_format(self.quant_format):
            if self._is_fp8_w8a8(weight_quant, input_quant):
                is_fp8_w8a8_supported = self._check_scheme_supported(
                    CompressedTensorsW8A8Fp8.get_min_capability(), error=False
                )
                if is_fp8_w8a8_supported:
                    return CompressedTensorsW8A8Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=(
                            input_quant and not input_quant.dynamic
                        ),
                    )
                else:
                    # note: input_quant will be present for converted models;
                    # will be ignored during inference post loading
                    return CompressedTensorsW8A16Fp8(
                        strategy=weight_quant.strategy,
                        is_static_input_scheme=not input_quant.dynamic,
                    )

            # note: input_quant can be None
            if self._is_fp8_w8a16(weight_quant, input_quant):
                if not VLLM_AVAILABLE:
                    raise ImportError(
                        "vllm is not installed, to use CompressedTensorsW8A16Fp8, please install vllm"
                    )
                is_static_input_scheme = input_quant and not input_quant.dynamic
                return CompressedTensorsW8A16Fp8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=is_static_input_scheme,
                )

            if self._is_static_tensor_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=True,
                    input_symmetric=input_quant.symmetric,
                )

            if self._is_dynamic_token_w8a8(weight_quant, input_quant):
                return CompressedTensorsW8A8Int8(
                    strategy=weight_quant.strategy,
                    is_static_input_scheme=False,
                    input_symmetric=input_quant.symmetric,
                )

        raise NotImplementedError("No compressed-tensors compatible scheme was found.")

    def get_scheme(
        self, layer: torch.nn.Module, layer_name: Optional[str] = None
    ) -> Optional["CompressedTensorsScheme"]:
        """
        compressed-tensors supports non uniform in the following way:

        targets of config_groups: There can be N config_groups which each
            have a quantization scheme. Each config_group has a list of targets
            which can be a full layer_name, a regex for a layer_name, or
            an nn.Module name.

        Detect whether a layer_name is found in any target and
        use the quantization scheme corresponding to the matched target
        to select the CompressedTensorsScheme used for infernece.
        """

        # Find the "target" in the compressed-tensors config
        # that our layer conforms to.
        # TODO (@robertgshaw): add compressed-tensors as dep
        # so we do not have to re-write these functions
        # need to make accelerate optional in ct to do this

        # Will be empty for models with only sparsity
        weight_quant = input_quant = None
        if self.target_scheme_map:
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=self.target_scheme_map.keys(),
                fused_mapping=self.packed_modules_mapping,
            )

            scheme_dict = self.target_scheme_map[matched_target]
            weight_quant = scheme_dict.get("weights")
            input_quant = scheme_dict.get("input_activations")

        # Find the sparsity scheme of the layer
        # assume that fused layers inerhit first component's sparsity scheme
        sparsity_targets = self.sparsity_scheme_map.keys() - set(
            self.sparsity_ignore_list
        )
        sparsity_scheme: Optional[SparsityCompressionConfig] = None
        with suppress(ValueError):
            matched_target = find_matched_target(
                layer_name=layer_name,
                module=layer,
                targets=sparsity_targets,
                fused_mapping=self.packed_modules_mapping,
            )
            sparsity_scheme = self.sparsity_scheme_map[matched_target]

        if self.supports_cutlass_24(
            weight_quant=weight_quant,
            input_quant=input_quant,
            sparsity_scheme=sparsity_scheme,
        ):
            if not VLLM_AVAILABLE:
                raise ImportError(
                    "vllm is not installed, to use CompressedTensors24, please install vllm"
                )
            # Have a valid sparsity scheme
            # Validate layer is supported by Cutlass 2:4 Kernel
            model_compression_config = (
                None
                if sparsity_scheme is None or sparsity_scheme.format == "dense"
                else self.config
            )

            scheme = CompressedTensors24(
                quantized=weight_quant is not None or input_quant is not None,
                weight_quant=weight_quant,
                input_quant=input_quant,
                model_compression_config=model_compression_config,
            )
        elif weight_quant is None:
            logger.warning_once(
                "Acceleration for non-quantized schemes is "
                "not supported by Compressed Tensors. "
                "Falling back to UnquantizedLinearMethod"
            )
            return None

        else:
            # Find the quant_scheme
            scheme = self._get_scheme_from_parts(  # type: ignore
                weight_quant=weight_quant,
                input_quant=input_quant,
            )

        # Raise error if device does not support the scheme
        # (e.g. fp8 needs ada lovelace)
        self._check_scheme_supported(scheme.get_min_capability())
        logger.debug("Using scheme: %s for %s", scheme.__class__.__name__, layer_name)
        return scheme

    def get_cache_scale(self, name: str) -> Optional[str]:
        """
        Check whether the param name matches the format for k/v cache scales
        in compressed-tensors. If this is the case, return its equivalent
        param name expected by vLLM

        :param name: param name
        :return: matching param name for KV cache scale in vLLM
        """
        if name.endswith(".output_scale") and ".k_proj" in name:
            return name.replace(".k_proj.output_scale", ".attn.k_scale")
        if name.endswith(".output_scale") and ".v_proj" in name:
            return name.replace(".v_proj.output_scale", ".attn.v_scale")
        # If no matches, return None
        return None

    @staticmethod
    def supports_cutlass_24(
        weight_quant: Optional[QuantizationArgs],
        input_quant: Optional[QuantizationArgs],
        sparsity_scheme: Optional[SparsityCompressionConfig] = None,
    ) -> bool:
        """
        Check if the layer is supported by the Cutlass 2:4 Kernel
        Conditions:
            - Overarching condition: Sparsity Structure is 2:4
            - Unquantized cases are supported
            - Weight only quantization is not-supported
            - Supported weight quantization strategies are TENSOR and CHANNEL
            - Supported input quantization strategies are TENSOR and TOKEN
            - Only 8 bit quantization is supported

        :return: True if the layer is supported by the Cutlass 2:4 Kernel
            False otherwise
        """
        if sparsity_scheme is None:
            return False

        is_valid_sparsity_structure: bool = (
            sparsity_scheme.sparsity_structure == SparsityStructure.TWO_FOUR.value
        )

        valid_compressors = {
            CompressionFormat.dense.value,
            CompressionFormat.sparse_24_bitmask.value,
        }

        is_valid_sparsity = (
            is_valid_sparsity_structure and sparsity_scheme.format in valid_compressors
        )

        if not is_valid_sparsity:
            return False

        # Unquantized cases are supported
        if weight_quant is None and input_quant is None:
            return True

        # Weight only quantization is not-supported
        if weight_quant is not None and input_quant is None:
            return False

        supported_weight_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.CHANNEL.value,
        ]

        assert weight_quant is not None
        assert input_quant is not None
        if weight_quant.strategy not in supported_weight_quant_strategies:
            return False

        supported_input_quant_strategies = [
            QuantizationStrategy.TENSOR.value,
            QuantizationStrategy.TOKEN.value,
        ]

        if input_quant.strategy not in supported_input_quant_strategies:
            return False

        return weight_quant.num_bits == input_quant.num_bits == 8


class CompressedTensorsLinearMethod(LinearMethodBase):

    def __init__(self, quantization_config: CompressedTensorsConfig):
        self.quantization_config = quantization_config

    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
        layer.scheme.process_weights_after_loading(layer)

    def create_weights(
        self,
        layer: torch.nn.Module,
        input_size_per_partition: int,
        output_partition_sizes: List[int],
        input_size: int,
        output_size: int,
        params_dtype: torch.dtype,
        **extra_weight_attrs,
    ):
        """
        Use the CompressedTensorsScheme associated with each layer to create
        the necessary parameters for the layer. See LinearMethodBase for param
        details
        """
        weight_loader = extra_weight_attrs.get("weight_loader")
        layer.scheme.create_weights(
            layer=layer,
            input_size=input_size,
            input_size_per_partition=input_size_per_partition,
            output_partition_sizes=output_partition_sizes,
            output_size=output_size,
            params_dtype=params_dtype,
            weight_loader=weight_loader,
        )

    def apply(
        self,
        layer: torch.nn.Module,
        x: torch.Tensor,
        bias: Optional[torch.Tensor] = None,
    ):
        """
        Use the output of create_weights and the CompressedTensorsScheme
        associated with the layer to apply the forward pass with the
        layer input. See LinearMethodBase for param details

        """

        scheme = layer.scheme
        if scheme is None:
            raise ValueError("A scheme must be defined for each layer")
        return scheme.apply_weights(layer, x, bias=bias)