sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff shows the contents of two publicly available package versions as released to a supported public registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in that registry.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/gptq.py +501 -143

@@ -1,46 +1,62 @@
+from __future__ import annotations
+
 import logging
+from dataclasses import dataclass
 from fractions import Fraction
-from typing import Any, Callable, Dict, List, Optional, Union
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Union
 
 import torch
 
-from sglang.srt.layers.
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedColumnParameter,
+    PackedvLLMParameter,
+    RowvLLMParameter,
+    permute_param_layout_,
+)
 from sglang.srt.layers.quantization.base_config import (
+    FusedMoEMethodBase,
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.
-
+from sglang.srt.layers.quantization.marlin_utils import (
+    apply_gptq_marlin_linear,
+    check_marlin_supported,
+    check_marlin_supports_shape,
+    marlin_is_k_full,
+    marlin_make_empty_g_idx,
+    marlin_make_workspace,
+    marlin_moe_permute_scales,
+    marlin_permute_scales,
+    marlin_repeat_scales_on_all_ranks,
+    marlin_sort_g_idx,
+    marlin_zero_points,
+    verify_marlin_supported,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.srt.layers.quantization.utils import (
+    get_linear_quant_method,
+    replace_parameter,
+    unpack_cols,
+)
 
-
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
 
 try:
     from vllm import _custom_ops as ops
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        FusedMoE,
-        FusedMoEMethodBase,
-        FusedMoeWeightScaleSupported,
-        GPTQMarlinLinearMethod,
-        marlin_moe_permute_scales,
-    )
-    from vllm.model_executor.layers.quantization.marlin import MarlinLinearMethod
-    from vllm.model_executor.layers.quantization.utils.marlin_utils import (
-        check_marlin_supported,
-    )
-    from vllm.scalar_type import scalar_types
-
-    VLLM_AVAILABLE = True
 except ImportError:
-
+    ops = None
 
-
+from sglang.srt.utils import is_cuda
 
-
+_is_cuda = is_cuda()
 
-
-
-uint8b128 = "uint8b128"
+if _is_cuda:
+    from sgl_kernel import fused_marlin_moe
 
 
 logger = logging.getLogger(__name__)
@@ -54,6 +70,38 @@ def check_marlin_format(hf_quant_cfg: Dict[str, Any]) -> bool:
     )
 
 
+def gptq_marlin_moe_repack(
+    b_q_weight: torch.Tensor,
+    perm: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    num_bits: int,
+) -> torch.Tensor:
+    num_experts = b_q_weight.shape[0]
+    assert size_k % 16 == 0
+    output = torch.empty(
+        (num_experts, size_k // 16, size_n * (num_bits // 2)),
+        device=b_q_weight.device,
+        dtype=b_q_weight.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = torch.ops.sgl_kernel.gptq_marlin_repack(
+            b_q_weight[e], perm[e], size_k, size_n, num_bits
+        )
+    return output
+
+
+@dataclass
+class MarlinLinearLayerConfig:
+    full_weight_shape: tuple[int, int]  # [in, out]
+    partition_weight_shape: tuple[int, int]
+    weight_type: ScalarType
+    act_type: torch.dtype
+    group_size: int
+    zero_points: bool
+    has_g_idx: bool
+
+
 class GPTQConfig(QuantizationConfig):
     """Config class for GPTQ.
 
@@ -139,7 +187,7 @@ class GPTQConfig(QuantizationConfig):
         return ["quantize_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> GPTQConfig:
         dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
         dynamic = {} if dynamic is None else dynamic
 
@@ -151,11 +199,16 @@ class GPTQConfig(QuantizationConfig):
 
     def get_quant_method(
         self, layer: torch.nn.Module, prefix: str
-    ) -> Optional[
+    ) -> Optional[LinearMethodBase]:
         # Delay the import to avoid circular dependency
-        from sglang.srt.layers.
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 
-
+        if isinstance(layer, LinearBase):
+            return get_linear_quant_method(self, layer, prefix, GPTQLinearMethod)
+        elif isinstance(layer, FusedMoE):
+            raise TypeError("GPTQ Method does not support MoE, please use gptq_marlin")
+        return None
 
 
 class GPTQMarlinConfig(QuantizationConfig):
@@ -258,7 +311,7 @@ class GPTQMarlinConfig(QuantizationConfig):
         return ["quantize_config.json"]
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> GPTQMarlinConfig:
         dynamic = cls.get_from_keys_or(config, ["dynamic"], default={})
         dynamic = {} if dynamic is None else dynamic
 
@@ -309,18 +362,9 @@ class GPTQMarlinConfig(QuantizationConfig):
     ) -> Optional[QuantizeMethodBase]:
         # Delay the import to avoid circular dependency
         from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
-        from sglang.srt.layers.quantization import get_linear_quant_method
 
         if isinstance(layer, FusedMoE):
             return GPTQMarlinMoEMethod(self)
-        # TODO: re-enable after SGLang syncs with vllm >= 0.7.3
-        # if layer.num_experts > 32:
-        #     # For MoEs with many experts the moe_wna16 kernel is faster
-        #     return MoeWNA16Config.from_config(self.full_config).get_quant_method(
-        #         layer, prefix
-        #     )
-        # else:
-        #     return GPTQMarlinMoEMethod(self)
         return get_linear_quant_method(self, layer, prefix, GPTQMarlinLinearMethod)
 
     @classmethod
@@ -344,112 +388,439 @@ class GPTQMarlinConfig(QuantizationConfig):
         if (num_bits, sym) not in cls.TYPE_MAP:
             return False
 
-        assert (
-            VLLM_AVAILABLE
-        ), "vllm is not installed, to use gptq_marlin, please install vllm"
-
         return check_marlin_supported(
             quant_type=cls.TYPE_MAP[(num_bits, sym)], group_size=group_size
         )
 
 
-class
-    """
+class GPTQLinearMethod(LinearMethodBase):
+    """Linear method for GPTQ.
 
-
+    Args:
+        quant_config: The GPTQ quantization config.
     """
 
-    def __init__(
+    def __init__(self, quant_config: GPTQConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
         self,
-
-
-
-
-
-
-
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+        weight_loader = extra_weight_attrs.get("weight_loader")
+        if input_size_per_partition % self.quant_config.group_size != 0:
             raise ValueError(
-                "
-                "
-
+                "The input size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.pack_factor.numerator != 0:
+            raise ValueError(
+                "The output size is not aligned with the quantized "
+                "weight shape. This can be caused by too large "
+                "tensor parallel size."
+            )
+
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
+
+        self.use_shuffle = True
+        scale_and_zero_size = input_size // group_size
+        scale_and_zero_input_dim = None
+        if (
+            input_size != input_size_per_partition
+            and self.quant_config.group_size != -1
+        ):
+            if self.quant_config.desc_act:
+                self.use_shuffle = False
+            else:
+                # we need to partition qzeros and scales for exllama kernel
+                scale_and_zero_size = input_size_per_partition // group_size
+                scale_and_zero_input_dim = 0
+
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
+
+        g_idx = RowvLLMParameter(
+            data=torch.tensor(
+                [
+                    i // self.quant_config.group_size
+                    for i in range(input_size_per_partition)
+                ],
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            weight_loader=weight_loader,
+        )
+        qzeros_args = {
+            "data": torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            "weight_loader": weight_loader,
+        }
+        weight_scale_args = {
+            "data": torch.empty(
+                scale_and_zero_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+        if scale_and_zero_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
             )
 
-
-
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-
-
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("g_idx", g_idx)
+        layer.register_parameter("qzeros", qzeros)
+        layer.register_parameter("scales", scales)
 
-
-
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # for torch.compile
+        layer.qzeros = torch.nn.Parameter(layer.qzeros.data, requires_grad=False)
+        layer.qweight = torch.nn.Parameter(layer.qweight.data, requires_grad=False)
+        layer.g_idx = torch.nn.Parameter(layer.g_idx.data, requires_grad=False)
+        layer.scales = torch.nn.Parameter(layer.scales.data, requires_grad=False)
+
+        # exllama needs to shuffle the weight after the weight is loaded
+        # here we do the shuffle on first forward pass
+        if self.use_shuffle:
+            if self.quant_config.desc_act:
+                layer.g_idx.data = torch.argsort(layer.g_idx).to(torch.int)
+            else:
+                layer.g_idx.data = torch.empty(
+                    (0,), dtype=torch.int, device=layer.g_idx.device
+                )
+            ops.gptq_shuffle(layer.qweight, layer.g_idx, self.quant_config.weight_bits)
 
-
-        self
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        out_shape = x.shape[:-1] + (layer.qweight.shape[-1],)
+        reshaped_x = x.reshape(-1, x.shape[-1])
+
+        output = ops.gptq_gemm(
+            reshaped_x,
+            layer.qweight,
+            layer.qzeros,
+            layer.scales,
+            layer.g_idx,
+            self.use_shuffle,
+            self.quant_config.weight_bits,
+        )
+        if bias is not None:
+            output.add_(bias)
+        return output.reshape(out_shape)
 
-        # Max parallel problems to solve at once (improves large
-        # batch performance)
-        self.max_parallel = 16
 
-
-
+class GPTQMarlinLinearMethod(LinearMethodBase):
+    """Linear method for GPTQ Marlin.
 
-
-
-
-
+    Args:
+        quant_config: The GPTQ Marlin quantization config.
+    """
+
+    _kernel_backends_being_used: set[str] = set()
+
+    def __init__(self, quant_config: GPTQMarlinConfig) -> None:
+        self.quant_config = quant_config
+
+        # Verify supported on platform.
+        verify_marlin_supported(
+            quant_type=self.quant_config.quant_type,
+            group_size=self.quant_config.group_size,
         )
 
-
-
-
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ) -> None:
+        output_size_per_partition = sum(output_partition_sizes)
+        is_row_parallel = input_size != input_size_per_partition
+        weight_loader = extra_weight_attrs.get("weight_loader")
+
+        self.kernel_config = MarlinLinearLayerConfig(
+            full_weight_shape=(input_size, output_size),
+            partition_weight_shape=(
+                input_size_per_partition,
+                output_size_per_partition,
+            ),
+            weight_type=self.quant_config.quant_type,
+            act_type=params_dtype,
+            group_size=self.quant_config.group_size,
+            zero_points=False,
+            has_g_idx=self.quant_config.desc_act,
+        )
+        # Normalize group_size
+        if self.quant_config.group_size != -1:
+            group_size = self.quant_config.group_size
+        else:
+            group_size = input_size
 
-
-
-
+        # Determine sharding
+        if marlin_repeat_scales_on_all_ranks(
+            self.quant_config.desc_act, self.quant_config.group_size, is_row_parallel
+        ):
+            # By setting scale_dim == None, weight_loader will
+            # repeat the scales on each GPU in TP>1 case.
+            scales_and_zp_input_dim = None
+            scales_and_zp_size = input_size // group_size
+        else:
+            # By setting scale_dim == 0, weight_loader will
+            # shard the scales in TP>1 case.
+            scales_and_zp_input_dim = 0
+            scales_and_zp_size = input_size_per_partition // group_size
+
+        # Quantized weights
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.pack_factor,
+                output_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=0,
+            packed_factor=self.quant_config.pack_factor,
+            weight_loader=weight_loader,
+        )
 
-
-
-
-
+        # Activation order
+        g_idx = RowvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition,
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            weight_loader=weight_loader,
+        )
 
-
-
-
+        qzeros_args = {
+            "data": torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition // self.quant_config.pack_factor,
+                dtype=torch.int32,
+            ),
+            "weight_loader": weight_loader,
+        }
+        weight_scale_args = {
+            "data": torch.empty(
+                scales_and_zp_size,
+                output_size_per_partition,
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+
+        if scales_and_zp_input_dim is None:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+            qzeros = PackedColumnParameter(
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-
-
-
-
-
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+            qzeros = PackedvLLMParameter(
+                input_dim=0,
+                output_dim=1,
+                packed_dim=1,
+                packed_factor=self.quant_config.pack_factor,
+                **qzeros_args,
+            )
 
-
-
-
+        layer.register_parameter("qweight", qweight)
+        layer.register_parameter("g_idx", g_idx)
+        layer.register_parameter("scales", scales)
+        layer.register_parameter("qzeros", qzeros)
 
-
-
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        device = getattr(layer, "qweight").device
+        c = self.kernel_config
+
+        check_marlin_supports_shape(
+            c.partition_weight_shape[1],  # out_features
+            c.partition_weight_shape[0],  # in_features
+            c.full_weight_shape[0],  # in_features
+            c.group_size,
         )
 
-
-
-
+        row_parallel = c.partition_weight_shape[0] != c.full_weight_shape[0]
+        self.is_k_full = marlin_is_k_full(c.has_g_idx, row_parallel)
+
+        # Allocate marlin workspace.
+        self.workspace = marlin_make_workspace(device)
+
+        # Default names since marlin requires empty parameters for these,
+        # TODO: remove this requirement from marlin (allow optional tensors)
+        self.w_q_name = "qweight"
+        self.w_s_name = "scales"
+        self.w_zp_name = "qzeros"
+        self.w_gidx_name = "g_idx"
+
+        def _transform_param(
+            layer: torch.nn.Module, name: Optional[str], fn: Callable
+        ) -> None:
+            if name is not None and getattr(layer, name, None) is not None:
+
+                old_param = getattr(layer, name)
+                new_param = fn(old_param)
+                # replace the parameter with torch.nn.Parameter for TorchDynamo
+                # compatibility
+                replace_parameter(
+                    layer, name, torch.nn.Parameter(new_param.data, requires_grad=False)
+                )
+
+        def transform_w_q(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1, packed_dim=0)
+            x.data = torch.ops.sgl_kernel.gptq_marlin_repack(
+                x.data.contiguous(),
+                perm=layer.g_idx_sort_indices,
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                num_bits=c.weight_type.size_bits,
             )
-
-
+            return x
+
+        def transform_w_s(x):
+            assert isinstance(x, BasevLLMParameter)
+            permute_param_layout_(x, input_dim=0, output_dim=1)
+            x.data = marlin_permute_scales(
+                x.data.contiguous(),
+                size_k=c.partition_weight_shape[0],
+                size_n=c.partition_weight_shape[1],
+                group_size=c.group_size,
+            )
+            return x
 
-
+        if c.has_g_idx:
+            g_idx, g_idx_sort_indices = marlin_sort_g_idx(
+                getattr(layer, self.w_gidx_name)
+            )
+            _transform_param(layer, self.w_gidx_name, lambda _: g_idx)
+            layer.g_idx_sort_indices = g_idx_sort_indices
+        else:
+            setattr(layer, self.w_gidx_name, marlin_make_empty_g_idx(device))
+            layer.g_idx_sort_indices = marlin_make_empty_g_idx(device)
 
-
-
-
-
-
+        if c.zero_points:
+            grouped_k = (
+                c.partition_weight_shape[0] // c.group_size if c.group_size != -1 else 1
+            )
+            _transform_param(
+                layer,
+                self.w_zp_name,
+                lambda x: marlin_zero_points(
+                    unpack_cols(
+                        x.t(),
+                        c.weight_type.size_bits,
+                        grouped_k,
+                        c.partition_weight_shape[1],
+                    ),
+                    size_k=grouped_k,
+                    size_n=c.partition_weight_shape[1],
+                    num_bits=c.weight_type.size_bits,
+                ),
+            )
+        else:
+            setattr(layer, self.w_zp_name, marlin_make_empty_g_idx(device))
+        _transform_param(layer, self.w_q_name, transform_w_q)
+        _transform_param(layer, self.w_s_name, transform_w_s)
 
-
-
-
-
-
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        c = self.kernel_config
+
+        def _get_weight_params(
+            layer: torch.nn.Module,
+        ) -> tuple[
+            torch.Tensor,  # w_q
+            torch.Tensor,  # w_s
+            Optional[torch.Tensor],  # w_zp,
+            Optional[torch.Tensor],  # w_gidx
+        ]:
+            return (
+                getattr(layer, self.w_q_name),
+                getattr(layer, self.w_s_name),
+                getattr(layer, self.w_zp_name or "", None),
+                getattr(layer, self.w_gidx_name or "", None),
+            )
+
+        w_q, w_s, w_zp, w_gidx = _get_weight_params(layer)
+
+        # `process_weights_after_loading` will ensure w_zp and w_gidx are not
+        # None for marlin
+        return apply_gptq_marlin_linear(
+            input=x,
+            weight=w_q,
+            weight_scale=w_s,
+            weight_zp=w_zp,  # type: ignore
+            g_idx=w_gidx,  # type: ignore
+            g_idx_sort_indices=layer.g_idx_sort_indices,
+            workspace=self.workspace,
+            wtype=c.weight_type,
+            input_size_per_partition=c.partition_weight_shape[0],
+            output_size_per_partition=c.partition_weight_shape[1],
+            is_k_full=self.is_k_full,
+            bias=bias,
+        )
 
 
 class GPTQMarlinMoEMethod(FusedMoEMethodBase):
@@ -467,6 +838,10 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         params_dtype: torch.dtype,
         **extra_weight_attrs,
     ):
+        # Delay the import to avoid circular dependency
+        from sglang.srt.layers.linear import set_weight_attrs
+        from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
+
         intermediate_size = extra_weight_attrs.pop("intermediate_size")
 
         self.is_k_full = (not self.quant_config.desc_act) or (
@@ -644,20 +1019,20 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )
         # Repack weights
-        marlin_w13_qweight =
+        marlin_w13_qweight = gptq_marlin_moe_repack(
            layer.w13_qweight,
            layer.w13_g_idx_sort_indices,
            layer.w13_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w13_qweight.shape[2],
-            self.quant_config.
+            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
-        marlin_w2_qweight =
+        marlin_w2_qweight = gptq_marlin_moe_repack(
            layer.w2_qweight,
            layer.w2_g_idx_sort_indices,
            layer.w2_qweight.shape[1] * self.quant_config.pack_factor,
            layer.w2_qweight.shape[2],
-            self.quant_config.
+            self.quant_config.weight_bits,
        )
        replace_parameter(layer, "w2_qweight", marlin_w2_qweight)
        # Repack scales
@@ -685,39 +1060,22 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        e_score_correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        **kwargs,
     ) -> torch.Tensor:
+        # Delay the import to avoid circular dependency
+
         assert activation == "silu", "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
         x = x.half()
 
-        topk_weights, topk_ids =
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            e_score_correction_bias=e_score_correction_bias,
-        )
+        topk_weights, topk_ids, router_logits = topk_output
 
-        return
+        return fused_marlin_moe(
             x,
             layer.w13_qweight,
             layer.w2_qweight,
@@ -730,6 +1088,6 @@ class GPTQMarlinMoEMethod(FusedMoEMethodBase):
             g_idx2=layer.w2_g_idx,
             sort_indices1=layer.w13_g_idx_sort_indices,
             sort_indices2=layer.w2_g_idx_sort_indices,
-
+            num_bits=self.quant_config.weight_bits,
             is_k_full=self.is_k_full,
         ).to(orig_dtype)