sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
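The listing above includes a one-line change to sglang/version.py, which carries the version bump from 0.5.0rc2 to 0.5.1. A minimal sketch (not part of the diff) for confirming the installed wheel after upgrading with pip, assuming version.py follows the usual convention of defining `__version__`:

```python
# Hedged check of the installed wheel's version string; sglang/version.py
# (changed +1 -1 in the listing above) defines __version__.
from sglang.version import __version__

print(__version__)  # expected "0.5.1" after upgrading from 0.5.0rc2
```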
sglang/srt/layers/quantization/quark/quark.py

@@ -0,0 +1,390 @@
+# SPDX-License-Identifier: Apache-2.0
+
+import fnmatch
+import logging
+from typing import Any, List, Optional, cast
+
+import torch
+
+from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
+from sglang.srt.layers.quantization.base_config import (  # noqa: E501
+    LinearMethodBase,
+    QuantizationConfig,
+    QuantizeMethodBase,
+)
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.quark.quark_moe import QuarkMoEMethod
+from sglang.srt.layers.quantization.quark.schemes import QuarkScheme, QuarkW4A4MXFP4
+from sglang.srt.layers.quantization.quark.utils import deep_compare, should_ignore_layer
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.utils import get_device_capability
+
+__all__ = ["QuarkLinearMethod"]
+
+logger = logging.getLogger(__name__)
+
+
+class QuarkConfig(QuantizationConfig):
+
+    def __init__(
+        self,
+        quant_config: dict[str, Any],
+        kv_cache_group: Optional[list[str]] = None,
+        kv_cache_config: Optional[dict[str, Any]] = None,
+        pack_method: str = "reorder",
+    ):
+        super().__init__()
+        if kv_cache_group is None:
+            kv_cache_group = []
+        self.quant_config = quant_config
+        self.kv_cache_group = kv_cache_group
+        self.kv_cache_config = kv_cache_config
+        self.pack_method = pack_method
+
+        self.packed_modules_mapping = self.quant_config["packed_modules_mapping"]
+
+    def get_linear_method(self) -> "QuarkLinearMethod":
+        return QuarkLinearMethod(self)
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.float16, torch.bfloat16]
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        return 70
+
+    def get_name(self) -> str:
+        return "quark"
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional["QuantizeMethodBase"]:
+        # Check if the layer is skipped for quantization.
+        exclude_layers = cast(list[str], self.quant_config.get("exclude"))
+        if should_ignore_layer(
+            prefix, ignore=exclude_layers, fused_mapping=self.packed_modules_mapping
+        ):
+            return UnquantizedLinearMethod()
+
+        if isinstance(layer, LinearBase):
+            scheme = self.get_scheme(layer=layer, layer_name=prefix)
+            layer.scheme = scheme
+            return QuarkLinearMethod(self)
+
+        if isinstance(layer, RadixAttention):
+            return QuarkKVCacheMethod(self)
+
+        from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+
+        if isinstance(layer, FusedMoE):
+            return QuarkMoEMethod.get_moe_method(self, module=layer, layer_name=prefix)
+
+        return None
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "QuarkConfig":
+        export_config = config.get("export")
+        if export_config is None:
+            raise ValueError(
+                "The export key should be included in "
+                "the configurations of Quark quantized model"
+            )
+
+        kv_cache_group = cast(list[str], export_config.get("kv_cache_group"))
+        pack_method = cast(str, export_config.get("pack_method"))
+
+        # In the export model of quark, the quantization configuration
+        # of kv_cache is stored in layer_quant_config. First, it is
+        # judged whether kv_cache_group exists, and then it is judged
+        # whether layer_quant_config has a quantization configuration
+        # that matches kv_cache.
+        if len(kv_cache_group) == 0:
+            kv_cache_config = None
+        else:
+            kv_cache_set = set(kv_cache_group)
+            layer_quant_config = cast(dict[str, Any], config.get("layer_quant_config"))
+            layer_quant_names = list(layer_quant_config.keys())
+            layer_quant_set = set(layer_quant_names)
+
+            if not kv_cache_set.issubset(layer_quant_set):
+                raise ValueError(
+                    "The Quark quantized model has the "
+                    "kv_cache_group parameter setting, "
+                    "but no kv_cache quantization settings "
+                    "were found in the quantization "
+                    "configuration."
+                )
+
+            q_configs = [
+                cast(dict[str, Any], layer_quant_config.get(name))
+                for name in kv_cache_group
+            ]
+            if not all(deep_compare(q_config, q_configs[0]) for q_config in q_configs):
+                raise ValueError(
+                    "The quantization method used for kv_cache should "
+                    "be the same, but the quantization method for the "
+                    "kv_cache layer in the config is different."
+                )
+            kv_cache_config = q_configs[0].get("output_tensors")
+            if kv_cache_config is None:
+                raise ValueError("The kv_cache quantization configuration is empty.")
+
+            # Since we have already set kv_cache quantization configurations,
+            # we will remove the quantization configuration for the
+            # output_tensors corresponding to the kv_cache layer.
+            for q_config in q_configs:
+                q_config["output_tensors"] = None
+
+            # In case q_proj output is also quantized, remove the configuration
+            # to keep qkv consistency.
+            q_proj_q_config = cast(dict[str, Any], layer_quant_config.get("*q_proj"))
+            if q_proj_q_config is not None:
+                q_proj_q_config["output_tensors"] = None
+
+        return cls(
+            quant_config=config,
+            kv_cache_group=kv_cache_group,
+            kv_cache_config=kv_cache_config,
+            pack_method=pack_method,
+        )
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return []
+
+    def _check_scheme_supported(self, min_capability: int, error: bool = True) -> bool:
+        capability_tuple = get_device_capability()
+
+        if capability_tuple is not None:
+            assert 0 <= capability_tuple[1] < 10
+            capability = capability_tuple[0] * 10 + capability_tuple[1]
+
+            supported = capability >= min_capability
+            if error and not supported:
+                raise RuntimeError(
+                    "Quantization scheme is not supported for ",
+                    f"the current GPU. Min capability: {min_capability}. ",
+                    f"Current capability: {capability}.",
+                )
+            return supported
+        else:
+            return False
+
+    def _is_mx_fp4(
+        self,
+        weight_quant: Optional[dict[str, Any]],
+        input_quant: Optional[dict[str, Any]],
+    ) -> bool:
+        # Confirm weights and input quantized.
+        if weight_quant is None or input_quant is None:
+            logger.debug(
+                "Quark model is not in MX-FP4 format: "
+                "weight_quant or input_quant not set"
+            )
+            return False
+
+        # Input and weight dtype needs to be fp4.
+        if weight_quant.get("dtype") != "fp4" or input_quant.get("dtype") != "fp4":
+            logger.debug("Quark model is not in MX-FP4 format: dtype not fp4")
+            return False
+
+        # Input and weight qscheme needs to be per group.
+        if (
+            weight_quant.get("qscheme") != "per_group"
+            or input_quant.get("qscheme") != "per_group"
+        ):
+            logger.debug("Quark model is not in MX-FP4 format: not per_group")
+            return False
+
+        # Input and weight group size needs to be 32.
+        if weight_quant.get("group_size") != 32 or input_quant.get("group_size") != 32:
+            logger.debug("Quark model is not in MX-FP4 format: not group_size=32")
+            return False
+
+        # Weights need to use static quantization.
+        if weight_quant.get("is_dynamic") is True:
+            logger.debug("Quark model is not in MX-FP4 format: not weight static")
+            return False
+
+        # Activations need to use dynamic quantization.
+        if input_quant.get("is_dynamic") is False:
+            logger.debug("Quark model is not in MX-FP4 format: not activation dynamic")
+            return False
+
+        # Activations and weight scales need to be in e8m0 format.
+        if (
+            weight_quant.get("scale_format") != "e8m0"
+            or input_quant.get("scale_format") != "e8m0"
+        ):
+            logger.debug("Quark model is not in MX-FP4 format: not scale_format e8m0")
+            return False
+
+        return True
+
+    def _find_matched_config(
+        self, layer_name: str, module: torch.nn.Module
+    ) -> dict[str, Any]:
+
+        proj_name = layer_name.split(".")[-1]
+        if proj_name in self.packed_modules_mapping:
+            shard_proj_names = self.packed_modules_mapping[proj_name]
+
+            # Convert fused_name --> [shard_names]
+            shard_names = [
+                layer_name.replace(proj_name, shard_proj_name)
+                for shard_proj_name in shard_proj_names
+            ]
+            shard_configs = [
+                self._find_matched_config(shard_name, module)
+                for shard_name in shard_names
+            ]
+            if not all(
+                deep_compare(q_config, shard_configs[0]) for q_config in shard_configs
+            ):
+                raise ValueError(
+                    f"Found a different quantization configuration for "
+                    f"{shard_proj_names} in {layer_name}. vLLM "
+                    "requires all to use the same scheme."
+                )
+            return shard_configs[0]
+        else:
+            layer_quant_config = cast(
+                dict[str, Any], self.quant_config.get("layer_quant_config")
+            )
+            for name_pattern in layer_quant_config:
+                if fnmatch.fnmatch(layer_name, name_pattern):
+                    return layer_quant_config[name_pattern]
+
+            layer_type = type(module).__name__
+            layer_type_quant_config = cast(
+                dict[str, Any], self.quant_config.get("layer_type_quant_config")
+            )
+            if layer_type in layer_type_quant_config:
+                return layer_type_quant_config[layer_type]
+
+            global_quant_config = cast(
+                dict[str, Any], self.quant_config.get("global_quant_config")
+            )
+            return global_quant_config
+
+    def _get_scheme_from_config(self, config: dict[str, Any]) -> "QuarkScheme":
+        if config.get("output_tensors") or config.get("bias"):
+            raise NotImplementedError(
+                "Currently, Quark models with output_tensors "
+                "and bias quantized are not supported"
+            )
+        weight_config = cast(dict[str, Any], config.get("weight"))
+        input_config = cast(dict[str, Any], config.get("input_tensors"))
+
+        if self._is_mx_fp4(weight_config, input_config):
+            return QuarkW4A4MXFP4(weight_config, input_config)
+
+        raise NotImplementedError(
+            "No quark compatible scheme was found. "
+            f"Weight config: {weight_config}, "
+            f"Input config: {input_config}"
+        )
+
+    def get_scheme(self, layer: torch.nn.Module, layer_name: str) -> "QuarkScheme":
+
+        layer_quant_config = self._find_matched_config(layer_name, layer)
+
+        # Find the quant_scheme
+        scheme = self._get_scheme_from_config(layer_quant_config)
+
+        # Raise error if device does not support the scheme
+        # (e.g. fp8 needs ada lovelace)
+        self._check_scheme_supported(scheme.get_min_capability())
+
+        return scheme
+
+    def get_scaled_act_names(self) -> List[str]:
+        return []
+
+
+class QuarkLinearMethod(LinearMethodBase):
+
+    def __init__(self, quantization_config: QuarkConfig):
+        self.quantization_config = quantization_config
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        layer.scheme.process_weights_after_loading(layer)
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        """
+        Use the CompressedTensorsScheme associated with each layer to create
+        the necessary parameters for the layer. See LinearMethodBase for param
+        details
+        """
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        layer.scheme.create_weights(
+            layer=layer,
+            input_size=input_size,
+            input_size_per_partition=input_size_per_partition,
+            output_partition_sizes=output_partition_sizes,
+            output_size=output_size,
+            params_dtype=params_dtype,
+            weight_loader=weight_loader,
+        )
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ):
+        """
+        Use the output of create_weights and the CompressedTensorsScheme
+        associated with the layer to apply the forward pass with the
+        layer input. See LinearMethodBase for param details
+
+        """
+        scheme = layer.scheme
+        if scheme is None:
+            raise ValueError("A scheme must be defined for each layer")
+        return scheme.apply_weights(layer, x, bias=bias)
+
+
+class QuarkKVCacheMethod(BaseKVCacheMethod):
+    """
+    Supports loading kv-cache scaling factors from quark checkpoints.
+    """
+
+    def __init__(self, quant_config: QuarkConfig):
+        self.validate_kv_cache_config(quant_config.kv_cache_config)
+        super().__init__(quant_config)
+
+    @staticmethod
+    def validate_kv_cache_config(kv_cache_config: Optional[dict[str, Any]]):
+        """
+        Validator for the kv cache configuration. Useful for controlling the
+        kv cache quantization schemes, that are being supported in vLLM
+        :param kv_cache_config: the quark kv cache scheme
+        """
+        if kv_cache_config is None:
+            return
+
+        dtype = kv_cache_config.get("dtype")
+        if dtype != "fp8_e4m3":
+            raise NotImplementedError(
+                "Currently supported kv cache quantization is "
+                f"dtype=fp8_e4m3, however received {dtype}"
+            )
+
+        qscheme = kv_cache_config.get("qscheme")
+        if qscheme != "per_tensor":
+            raise NotImplementedError(
+                "Only support per-tensor scaling factor "
+                "for quark KV cache. "
+                f"Expected qscheme: per_tensor, found qscheme: {qscheme}"
+            )
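The new QuarkConfig above is built with `from_config()` from the quantization config shipped inside a Quark-exported checkpoint. A minimal sketch (not part of the package) of the dict shape that code path expects; the key names come from the code in this hunk, while the concrete values are illustrative assumptions, and running it requires an environment where sglang 0.5.1 and its aiter dependency import cleanly (quark_moe.py below imports aiter at module level):

```python
from sglang.srt.layers.quantization.quark.quark import QuarkConfig

# Hypothetical MX-FP4 tensor config matching the checks in _is_mx_fp4().
mx_fp4 = {
    "dtype": "fp4",
    "qscheme": "per_group",
    "group_size": 32,
    "is_dynamic": False,
    "scale_format": "e8m0",
}

quark_checkpoint_config = {
    # No kv_cache_group, so from_config() leaves kv_cache_config as None.
    "export": {"kv_cache_group": [], "pack_method": "reorder"},
    "global_quant_config": {
        "weight": mx_fp4,
        # Activations are dynamically quantized in the MX-FP4 path.
        "input_tensors": {**mx_fp4, "is_dynamic": True},
        "output_tensors": None,
        "bias": None,
    },
    "layer_quant_config": {},
    "layer_type_quant_config": {},
    "exclude": ["lm_head"],
    "packed_modules_mapping": {
        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
        "gate_up_proj": ["gate_proj", "up_proj"],
    },
}

cfg = QuarkConfig.from_config(quark_checkpoint_config)
assert cfg.get_name() == "quark"
```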
sglang/srt/layers/quantization/quark/quark_moe.py

@@ -0,0 +1,197 @@
+# SPDX-License-Identifier: Apache-2.0
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Callable, Optional
+
+import torch
+from aiter import ActivationType, QuantType, biased_grouped_topk
+from aiter.fused_moe import fused_moe
+from aiter.utility.fp4_utils import e8m0_shuffle
+
+from sglang.srt.utils import get_bool_env_var, mxfp_supported, set_weight_attrs
+
+logger = logging.getLogger(__name__)
+
+__all__ = ["QuarkMoEMethod", "QuarkW4A4MXFp4MoEMethod"]
+
+OCP_MX_BLOCK_SIZE = 32
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
+
+class QuarkMoEMethod:
+    def __new__(cls, *args, **kwargs):
+        from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
+
+        if not hasattr(cls, "_initialized"):
+            original_init = cls.__init__
+            new_cls = type(
+                cls.__name__,
+                (FusedMoEMethodBase,),
+                {
+                    "__init__": original_init,
+                    **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
+                },
+            )
+            obj = super(new_cls, new_cls).__new__(new_cls)
+            obj.__init__(*args, **kwargs)
+            return obj
+        return super().__new__(cls)
+
+    @staticmethod
+    def get_moe_method(
+        quant_config: "QuarkConfig",  # type: ignore # noqa E501 # noqa F821
+        module: torch.nn.Module,
+        layer_name: str,
+    ) -> "QuarkMoEMethod":
+        layer_quant_config = quant_config._find_matched_config(layer_name, module)
+
+        if layer_quant_config.get("output_tensors") or layer_quant_config.get("bias"):
+            raise NotImplementedError(
+                "Currently, Quark models with "
+                "output_tensors and bias "
+                "quantized are not supported"
+            )
+        weight_config = layer_quant_config.get("weight")
+        input_config = layer_quant_config.get("input_tensors")
+
+        if quant_config._is_mx_fp4(weight_config, input_config):
+            return QuarkW4A4MXFp4MoEMethod(weight_config, input_config)
+        else:
+            raise RuntimeError("Unsupported FusedMoe scheme")
+
+
+class QuarkW4A4MXFp4MoEMethod(QuarkMoEMethod):
+
+    def __init__(self, weight_config: dict[str, Any], input_config: dict[str, Any]):
+        self.weight_quant = weight_config
+        self.input_quant = input_config
+
+        weight_qscheme = self.weight_quant.get("qscheme")
+        input_qscheme = self.input_quant.get("qscheme")
+        if not (weight_qscheme == "per_group" and input_qscheme == "per_group"):
+            raise ValueError(
+                "For MX(FP4) Fused MoE layers, only per-group scales "
+                "for weights and activations are supported. Found "
+                f"{weight_qscheme}, {input_qscheme}"
+            )  # noqa E501
+
+        self.static_input_scales = not self.input_quant.get("is_dynamic")
+        self.with_bias = False
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        num_experts: int,
+        hidden_size: int,
+        intermediate_size_per_partition: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
+        # Add the quantization method used (per tensor/grouped/channel)
+        # to ensure the weight scales are loaded in properly
+        extra_weight_attrs.update(
+            {"quant_method": FusedMoeWeightScaleSupported.BLOCK.value}
+        )
+
+        params_dtype = torch.uint8
+
+        # WEIGHTS
+        w13_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w13_weight", w13_weight)
+
+        set_weight_attrs(w13_weight, extra_weight_attrs)
+
+        w2_weight = torch.nn.Parameter(
+            torch.empty(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // 2,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        layer.register_parameter("w2_weight", w2_weight)
+
+        set_weight_attrs(w2_weight, extra_weight_attrs)
+
+        # WEIGHT_SCALES
+        w13_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                2 * intermediate_size_per_partition,
+                hidden_size // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        w2_weight_scale = torch.nn.Parameter(
+            torch.ones(
+                num_experts,
+                hidden_size,
+                intermediate_size_per_partition // OCP_MX_BLOCK_SIZE,
+                dtype=params_dtype,
+            ),
+            requires_grad=False,
+        )
+        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
+        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
+
+        layer.register_parameter("w13_weight_scale", w13_weight_scale)
+        layer.register_parameter("w2_weight_scale", w2_weight_scale)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        float_dtype = torch.get_default_dtype()
+
+        # Pre-shuffle weight scales
+        s0, s1, _ = layer.w13_weight_scale.shape
+        w13_weight_scale = layer.w13_weight_scale.view(s0 * s1, -1)
+        w13_weight_scale = e8m0_shuffle(w13_weight_scale)
+        # layer.w13_weight_scale = torch.nn.Parameter(w13_weight_scale, requires_grad=False)
+        layer.w13_weight_scale.data = w13_weight_scale.view(s0, s1, -1)
+
+        s0, s1, _ = layer.w2_weight_scale.shape
+        w2_weight_scale = layer.w2_weight_scale.view(s0 * s1, -1)
+        w2_weight_scale = e8m0_shuffle(w2_weight_scale)
+        # layer.w2_weight_scale = torch.nn.Parameter(w2_weight_scale, requires_grad=False)
+        layer.w2_weight_scale.data = w2_weight_scale.view(s0, s1, -1)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        topk_output: TopKOutput,
+        moe_runner_config: MoeRunnerConfig,
+    ) -> torch.Tensor:
+        topk_weights, topk_ids, _ = topk_output
+
+        return fused_moe(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights,
+            topk_ids,
+            quant_type=QuantType.per_1x32,
+            w1_scale=layer.w13_weight_scale,
+            w2_scale=layer.w2_weight_scale,
+            activation=(
+                ActivationType.Silu
+                if moe_runner_config.activation == "silu"
+                else ActivationType.Gelu
+            ),
+            doweight_stage1=False,
+        )
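In `create_weights()` above, the MoE weights are stored as uint8 with the last dimension halved (two FP4 values packed per byte), and one e8m0 scale is kept per 32-element block along the reduction dimension (`// OCP_MX_BLOCK_SIZE`). A plain-torch sketch of the resulting parameter shapes, using illustrative (hypothetical) dimensions rather than any real model's:

```python
import torch

# Hypothetical MoE dimensions; only the shape arithmetic mirrors create_weights().
num_experts = 8
hidden_size = 4096
intermediate_size_per_partition = 1024
OCP_MX_BLOCK_SIZE = 32  # one e8m0 scale per 32 FP4 values

# Two FP4 values are packed into each uint8 byte, so the reduction dim is halved.
w13_weight = torch.empty(
    num_experts, 2 * intermediate_size_per_partition, hidden_size // 2, dtype=torch.uint8
)
w2_weight = torch.empty(
    num_experts, hidden_size, intermediate_size_per_partition // 2, dtype=torch.uint8
)

# One scale per 32-element group along the reduction dimension.
w13_weight_scale = torch.ones(
    num_experts, 2 * intermediate_size_per_partition, hidden_size // OCP_MX_BLOCK_SIZE, dtype=torch.uint8
)
w2_weight_scale = torch.ones(
    num_experts, hidden_size, intermediate_size_per_partition // OCP_MX_BLOCK_SIZE, dtype=torch.uint8
)

print(w13_weight.shape, w13_weight_scale.shape)  # torch.Size([8, 2048, 2048]) torch.Size([8, 2048, 128])
print(w2_weight.shape, w2_weight_scale.shape)    # torch.Size([8, 4096, 512]) torch.Size([8, 4096, 32])
```

`process_weights_after_loading()` then reshuffles those scale tensors with aiter's `e8m0_shuffle` into the layout the fused MoE kernel expects, without changing their overall shape.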