sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff shows the changes between two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,790 @@
+# SPDX-License-Identifier: Apache-2.0
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/utils/marlin_utils.py
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any, Optional
+
+import numpy
+import torch
+
+from sglang.srt.layers.parameter import (
+    BasevLLMParameter,
+    ChannelQuantScaleParameter,
+    GroupQuantScaleParameter,
+    PackedvLLMParameter,
+)
+from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
+    QuantizationConfig,
+)
+from sglang.srt.layers.quantization.scalar_type import ScalarType, scalar_types
+from sglang.srt.layers.quantization.utils import pack_cols, unpack_cols
+from sglang.srt.utils import get_device_capability
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.linear import LinearBase
+
+try:
+    from vllm import _custom_ops as ops
+except ImportError:
+    ops = None
+
+logger = logging.getLogger(__name__)
+
+GPTQ_MARLIN_TILE = 16
+GPTQ_MARLIN_MIN_THREAD_N = 64
+GPTQ_MARLIN_MIN_THREAD_K = 128
+GPTQ_MARLIN_MAX_PARALLEL = 16
+
+MARLIN_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+# In case there is a performance issue with Marlin, the variable below can be
+# changed to False, which allows Marlin to perform global reductions in fp16
+# precision (instead of fp32), and therefore, save on some memory movements.
+USE_FP32_REDUCE_DEFAULT = True
+
+
+# For binary size and compile time, we don't support the same types for with and
+# without runtime zero-point. We support common cases, i.e. AWQ and GPTQ.
+# TODO: we may want to move this into the C++ so its closer to the actual impl
+def query_marlin_supported_quant_types(
+    has_zp: Optional[bool] = None,
+    include_fp_type: bool = True,
+    device_capability: Optional[int] = None,
+):
+    if device_capability is None:
+        major, minor = get_device_capability()
+        capability = major * 10 + minor
+        device_capability = -1 if capability is None else capability
+
+    if device_capability < 80:
+        return []
+
+    # - has_zp is True: return quant_types that has zero points
+    # - has_zp is False: return quant_types that has not zero points
+    # - has_zp is None: both
+    if has_zp is None:
+        types0 = query_marlin_supported_quant_types(
+            False, include_fp_type, device_capability
+        )
+        types1 = query_marlin_supported_quant_types(
+            True, include_fp_type, device_capability
+        )
+        return types0 + types1
+
+    if has_zp:
+        # AWQ style, unsigned + runtime zero-point
+        return [scalar_types.uint4]
+    else:
+        # GPTQ style, unsigned + symmetric bias
+        res = [scalar_types.uint4b8, scalar_types.uint8b128]
+        if include_fp_type:
+            res += [scalar_types.float8_e4m3fn, scalar_types.float4_e2m1f]
+        return res
+
+
+def _check_marlin_supported(
+    quant_type: ScalarType,
+    group_size: Optional[int],
+    has_zp: bool,
+    device_capability: Optional[int] = None,
+) -> tuple[bool, Optional[str]]:
+
+    if device_capability is None:
+        major, minor = get_device_capability()
+        capability = major * 10 + minor
+        device_capability = -1 if capability is None else capability
+
+    supported_types = query_marlin_supported_quant_types(
+        has_zp, True, device_capability
+    )
+
+    if quant_type not in supported_types:
+        return (
+            False,
+            f"Marlin does not support weight_bits = {quant_type}. "
+            f"Only types = {supported_types} "
+            f"are supported (for group_size = {group_size}, "
+            f"device_capability = {device_capability}, zp = {has_zp}).",
+        )
+    if group_size is None or group_size not in MARLIN_SUPPORTED_GROUP_SIZES:
+        return (
+            False,
+            f"Marlin does not support group_size = {group_size}. "
+            f"Only group_sizes = {MARLIN_SUPPORTED_GROUP_SIZES} "
+            "are supported.",
+        )
+
+    return True, None
+
+
+def check_marlin_supported(
+    quant_type: ScalarType,
+    group_size: int,
+    has_zp: bool = False,
+    device_capability: Optional[int] = None,
+) -> bool:
+    cond, _ = _check_marlin_supported(quant_type, group_size, has_zp, device_capability)
+    return cond
+
+
+def verify_marlin_supported(
+    quant_type: ScalarType, group_size: int, has_zp: bool = False
+) -> None:
+    cond, err_msg = _check_marlin_supported(quant_type, group_size, has_zp)
+    if not cond:
+        assert err_msg is not None
+        raise ValueError(err_msg)
+
+
+def verify_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> None:
+
+    # Validate output_size_per_partition
+    if output_size_per_partition % GPTQ_MARLIN_MIN_THREAD_N != 0:
+        raise ValueError(
+            f"Weight output_size_per_partition = "
+            f"{output_size_per_partition} is not divisible by "
+            f" min_thread_n = {GPTQ_MARLIN_MIN_THREAD_N}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+    # Validate input_size_per_partition
+    if input_size_per_partition % GPTQ_MARLIN_MIN_THREAD_K != 0:
+        raise ValueError(
+            f"Weight input_size_per_partition = "
+            f"{input_size_per_partition} is not divisible "
+            f"by min_thread_k = {GPTQ_MARLIN_MIN_THREAD_K}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+    if group_size < input_size and input_size_per_partition % group_size != 0:
+        raise ValueError(
+            f"Weight input_size_per_partition = {input_size_per_partition}"
+            f" is not divisible by group_size = {group_size}. "
+            "Consider reducing tensor_parallel_size or running "
+            "with --quantization gptq."
+        )
+
+
+def check_marlin_supports_shape(
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    input_size: int,
+    group_size: int,
+) -> tuple[bool, Optional[str]]:
+    try:
+        verify_marlin_supports_shape(
+            output_size_per_partition, input_size_per_partition, input_size, group_size
+        )
+    except ValueError as e:
+        return False, e.__str__()
+    return True, None
+
+
+def check_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+    output_size_per_partition = (
+        getattr(layer, "output_size_per_partition", None) or layer.output_size
+    )
+    input_size_per_partition = (
+        getattr(layer, "input_size_per_partition", None) or layer.input_size
+    )
+
+    return check_marlin_supports_shape(
+        output_size_per_partition=output_size_per_partition,
+        input_size_per_partition=input_size_per_partition,
+        input_size=layer.input_size,
+        group_size=group_size,
+    )[0]
+
+
+def check_moe_marlin_supports_layer(layer: LinearBase, group_size: int) -> bool:
+    hidden_size = layer.hidden_size
+    intermediate_size_per_partition = layer.intermediate_size_per_partition
+    # apply_router_weight_on_input is not supported for moe marlin
+    supports_router_weight = not layer.apply_router_weight_on_input
+    # moe marlin requires the activation to be silu
+    supports_activation = layer.activation == "silu"
+
+    # gate-up: (n, k) = (intermediate_size_per_partition * 2, hidden_size)
+    # down: (n, k) = (hidden_size, intermediate_size_per_partition)
+    # moe marlin requires n % 128 == 0 and k % 64 == 0
+    supports_shape = (
+        hidden_size % 128 == 0
+        and intermediate_size_per_partition % max(64, group_size) == 0
+    )
+    supports_group_size = group_size in [-1, 32, 64, 128]
+    return (
+        supports_shape
+        and supports_group_size
+        and supports_router_weight
+        and supports_activation
+    )
+
+
+def marlin_make_workspace(
+    device: torch.device, max_blocks_per_sm: int = 1
+) -> torch.Tensor:
+    # In the new marlin kernel, we use the num of threadblocks as workspace
+    # size. The num of threadblocks is is sms_count * max_blocks_per_sm.
+    sms = torch.cuda.get_device_properties(device).multi_processor_count
+    return torch.zeros(
+        sms * max_blocks_per_sm, dtype=torch.int, device=device, requires_grad=False
+    )
+
+
+def marlin_is_k_full(act_order: bool, is_row_parallel: bool) -> bool:
+    return (not act_order) or (act_order and not is_row_parallel)
+
+
+def marlin_repeat_scales_on_all_ranks(
+    act_order: bool, group_size: int, is_row_parallel: bool
+) -> bool:
+    # Need to repeat scales on every rank if act_ordering or
+    # channelwise and RowParallelLinear
+    is_channelwise = group_size == -1
+    return act_order or (is_channelwise and is_row_parallel)
+
+
+def marlin_make_empty_g_idx(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
+
+
+def marlin_make_empty_zp(device: torch.device) -> torch.Tensor:
+    return torch.nn.Parameter(
+        torch.empty(0, dtype=torch.int, device=device), requires_grad=False
+    )
+
+
+def marlin_sort_g_idx(g_idx: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    g_idx_sort_indices = torch.argsort(g_idx).to(torch.int)
+    return g_idx[g_idx_sort_indices], g_idx_sort_indices
+
+
+def get_scale_perms():
+    scale_perm: list[int] = []
+    for i in range(8):
+        scale_perm.extend([i + 8 * j for j in range(8)])
+    scale_perm_single: list[int] = []
+    for i in range(4):
+        scale_perm_single.extend([2 * i + j for j in [0, 1, 8, 9, 16, 17, 24, 25]])
+    return scale_perm, scale_perm_single
+
+
+def marlin_permute_scales(
+    s: torch.Tensor, size_k: int, size_n: int, group_size: int
+) -> torch.Tensor:
+
+    scale_perm, scale_perm_single = get_scale_perms()
+    if group_size < size_k and group_size != -1:
+        s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
+    else:
+        s = s.reshape((-1, len(scale_perm_single)))[:, scale_perm_single]
+    s = s.reshape((-1, size_n)).contiguous()
+
+    return s
+
+
+def marlin_moe_permute_scales(
+    s: torch.Tensor,
+    size_k: int,
+    size_n: int,
+    group_size: int,
+):
+    num_experts = s.shape[0]
+    output = torch.empty(
+        (num_experts, s.shape[1], s.shape[2]),
+        device=s.device,
+        dtype=s.dtype,
+    )
+
+    for e in range(num_experts):
+        output[e] = marlin_permute_scales(s[e], size_k, size_n, group_size)
+    return output
+
+
+def marlin_zero_points(
+    zp: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # Permute zero-points in a similar way to scales, but do not use the
+    # "single" permutation, since zero-points are applied on every MMA
+    scale_perm, _ = get_scale_perms()
+    zp = zp.reshape((-1, len(scale_perm)))[:, scale_perm]
+
+    # Interleave column dim (for the dequantize code) and pack it to int32
+    if num_bits == 4:
+        interleave = numpy.array([0, 2, 4, 6, 1, 3, 5, 7])
+    elif num_bits == 8:
+        interleave = numpy.array([0, 2, 1, 3])
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    zp = zp.reshape((-1, len(interleave)))[:, interleave].ravel()
+    zp = zp.reshape((-1, size_n)).contiguous()
+    zp = pack_cols(zp, num_bits, size_k, size_n)
+
+    return zp
+
+
+def awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+) -> torch.Tensor:
+    # AWQ zero-points are quantized and packed on the column dim.
+    # In addition, the values are permuted based on dequantizer.
+    # Here we undo both of these, and then apply marlin permutation
+    # and pack it back.
+    q_zp = unpack_cols(q_zp_packed, num_bits, size_k, size_n)
+
+    # Undo interleaving (use argsort(..) to get inverse perm)
+    if num_bits == 4:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 4, 6, 1, 3, 5, 7]))
+    elif num_bits == 8:
+        undo_interleave = numpy.argsort(numpy.array([0, 2, 1, 3]))
+    else:
+        raise Exception("num_bits must be 4 or 8, got {}".format(num_bits))
+
+    q_zp = q_zp.reshape((-1, len(undo_interleave)))[:, undo_interleave].ravel()
+    q_zp = q_zp.reshape((-1, size_n)).contiguous()
+
+    marlin_zp = marlin_zero_points(q_zp, size_k, size_n, num_bits)
+    return marlin_zp
+
+
+def moe_awq_to_marlin_zero_points(
+    q_zp_packed: torch.Tensor, size_k: int, size_n: int, num_bits: int
+):
+    num_experts = q_zp_packed.shape[0]
+    output = torch.empty(
+        (num_experts, q_zp_packed.shape[1], q_zp_packed.shape[2]),
+        device=q_zp_packed.device,
+        dtype=q_zp_packed.dtype,
+    )
+    for e in range(num_experts):
+        output[e] = awq_to_marlin_zero_points(q_zp_packed[e], size_k, size_n, num_bits)
+    return output
+
+
+def maybe_warn_marlin_atomic_add(device, dtype):
+    if torch.compiler.is_dynamo_compiling():
+        return
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        logger.info_once(
+            "You are running Marlin kernel with bf16 on GPUs before SM90. "
+            "You can consider change to fp16 to achieve better performance "
+            "if possible."
+        )
+
+
+def maybe_warn_marlin_atomic_add_env():
+    if torch.compiler.is_dynamo_compiling():
+        return
+    # TODO(yiyun): Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
+    if True:
+        return
+    # if envs.VLLM_MARLIN_USE_ATOMIC_ADD:
+    #     return
+    logger.info_once(
+        "Marlin kernel can achieve better performance for small size_n "
+        "with experimental use_atomic_add feature. "
+        "You can consider set environment variable "
+        "VLLM_MARLIN_USE_ATOMIC_ADD to 1 if possible."
+    )
+
+
+def should_use_atomic_add_reduce(
+    m: int, n: int, k: int, device: torch.device, dtype: torch.dtype
+) -> bool:
+
+    # the performance of atomicAdd is better than global reduce
+    # only when m*n is small and k is large
+    if n >= 2048 or k < 2048 or device.type != "cuda":
+        return False
+
+    # disable atomicAdd reduce by default,
+    # one can enable it with VLLM_MARLIN_USE_ATOMIC_ADD=1
+    # TODO: Need to add sglang's MARLIN_USE_ATOMIC_ADD: bool = False
+    if not True:
+        maybe_warn_marlin_atomic_add_env()
+        return False
+
+    # sm8x doesn't support atomicAdd + bfloat16 natively
+    device_capability = torch.cuda.get_device_capability(device)
+    if device_capability[0] < 9 and dtype == torch.bfloat16:
+        maybe_warn_marlin_atomic_add(device, dtype)
+        return False
+
+    return True
+
+
+def apply_gptq_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    wtype: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    is_k_full: bool,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        wtype,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        is_k_full=is_k_full,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+def apply_awq_marlin_linear(
+    input: torch.Tensor,
+    weight: torch.Tensor,
+    weight_scale: torch.Tensor,
+    weight_zp: torch.Tensor,
+    g_idx: torch.Tensor,
+    g_idx_sort_indices: torch.Tensor,
+    workspace: torch.Tensor,
+    quant_type: ScalarType,
+    output_size_per_partition: int,
+    input_size_per_partition: int,
+    bias: Optional[torch.Tensor] = None,
+    use_fp32_reduce: bool = USE_FP32_REDUCE_DEFAULT,
+) -> torch.Tensor:
+    reshaped_x = input.reshape(-1, input.shape[-1])
+    out_shape = input.shape[:-1] + (output_size_per_partition,)
+
+    use_atomic_add = should_use_atomic_add_reduce(
+        m=reshaped_x.size(0),
+        n=output_size_per_partition,
+        k=reshaped_x.size(1),
+        device=input.device,
+        dtype=input.dtype,
+    )
+
+    output = ops.gptq_marlin_gemm(
+        reshaped_x,
+        None,
+        weight,
+        weight_scale,
+        None,
+        weight_zp,
+        g_idx,
+        g_idx_sort_indices,
+        workspace,
+        quant_type,
+        size_m=reshaped_x.shape[0],
+        size_n=output_size_per_partition,
+        size_k=input_size_per_partition,
+        use_atomic_add=use_atomic_add,
+        use_fp32_reduce=use_fp32_reduce,
+        is_zp_float=False,
+    )
+
+    if bias is not None:
+        output.add_(bias)  # In-place add
+
+    return output.reshape(out_shape)
+
+
+class MarlinConfig(QuantizationConfig):
+    """Config class for Marlin.
+
+    Reference: https://github.com/IST-DASLab/marlin/tree/master
+    """
+
+    def __init__(
+        self,
+        group_size: int,
+        lm_head_quantized: bool,
+    ) -> None:
+        super().__init__()
+
+        # Group size for the quantization.
+        self.group_size = group_size
+        self.lm_head_quantized = lm_head_quantized
+        if self.group_size != 128 and self.group_size != -1:
+            raise ValueError(
+                "Currently, only group size 128 and -1 (channelwise) "
+                "is supported for Marlin, but got group_size of "
+                f"{self.group_size}"
+            )
+
+        # 4 Bits packed into 32 bit datatype.
+        self.pack_factor = 32 // 4
+
+        # Tile size used by marlin kernels.
+        self.tile_size = 16
+
+        # Min out_features dim
+        self.min_n_threads = 64
+
+        # Min in_features dim
+        self.min_k_threads = 128
+
+        # Max parallel problems to solve at once (improves large
+        # batch performance)
+        self.max_parallel = 16
+
+        # Permutation length used by the marlin kernels.
+        self.perm_len = 1024
+
+    def __repr__(self) -> str:
+        return (
+            f"MarlinConfig(group_size={self.group_size}, "
+            f"lm_head_quantized={self.lm_head_quantized})"
+        )
+
+    @classmethod
+    def get_name(cls) -> str:
+        return "marlin"
+
+    @classmethod
+    def get_supported_act_dtypes(cls) -> list[torch.dtype]:
+        return [torch.half]
+
+    @classmethod
+    # Need to figure it out
+    def get_min_capability(cls) -> int:
+        return 80
+
+    @classmethod
+    def get_config_filenames(cls) -> list[str]:
+        return ["quantize_config.json"]
+
+    @classmethod
+    def from_config(cls, config: dict[str, Any]) -> "MarlinConfig":
+        group_size = cls.get_from_keys(config, ["group_size"])
+        lm_head_quantized = cls.get_from_keys_or(config, ["lm_head"], default=False)
+        return cls(group_size, lm_head_quantized)
+
+    @classmethod
+    def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
+        # compat: autogptq >=0.8.0 use checkpoint_format: str
+        # compat: autogptq <=0.7.1 is_marlin_format: bool
+        is_marlin_format = hf_quant_cfg.get(
+            "checkpoint_format"
+        ) == "marlin" or hf_quant_cfg.get("is_marlin_format", False)
+
+        is_valid_user_quant = (
+            user_quant is None or user_quant == "gptq" or user_quant == "marlin"
+        )
+
+        if is_marlin_format and is_valid_user_quant:
+            msg = "The model is serialized in {} format. Using {} kernel.".format(
+                cls.get_name(), cls.get_name()
+            )
+            logger.info(msg)
+            return cls.get_name()
+
+        return None
+
+    def get_quant_method(
+        self, layer: torch.nn.Module, prefix: str
+    ) -> Optional[MarlinLinearMethod]:
+        from sglang.srt.layers.linear import LinearBase
+        from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+
+        if isinstance(layer, LinearBase) or (
+            isinstance(layer, ParallelLMHead) and self.lm_head_quantized
+        ):
+            return MarlinLinearMethod(self)
+        return None
+
+
+class MarlinLinearMethod(LinearMethodBase):
+    """Linear method for Marlin.
+
+    Args:
+        quant_config: The Marlin quantization config.
+    """
+
+    def __init__(self, quant_config: MarlinConfig):
+        self.quant_config = quant_config
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size_per_partition: int,
+        output_partition_sizes: list[int],
+        input_size: int,
+        output_size: int,
+        params_dtype: torch.dtype,
+        **extra_weight_attrs,
+    ):
+        del output_size  # Unused.
+        weight_loader = extra_weight_attrs["weight_loader"]
+
+        if params_dtype != torch.float16:
+            raise ValueError(
+                f"The params dtype must be float16, but got {params_dtype}"
+            )
+
+        # Validate output_size_per_partition
+        output_size_per_partition = sum(output_partition_sizes)
+        if output_size_per_partition % self.quant_config.min_n_threads != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"min_n_threads = {self.quant_config.min_n_threads}."
+            )
+        if output_size_per_partition % self.quant_config.pack_factor != 0:
+            raise ValueError(
+                f"Weight output_size_per_partition = "
+                f"{output_size_per_partition} is not divisible by "
+                f"pack_factor = {self.quant_config.pack_factor}."
+            )
+
+        # Validate input_size_per_partition
+        if input_size_per_partition % self.quant_config.min_k_threads != 0:
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"min_k_threads = {self.quant_config.min_k_threads}."
+            )
+        if (
+            self.quant_config.group_size != -1
+            and input_size_per_partition % self.quant_config.group_size != 0
+        ):
+            raise ValueError(
+                f"Weight input_size_per_partition = "
+                f"{input_size_per_partition} is not divisible by "
+                f"group_size = {self.quant_config.group_size}."
+            )
+
+        # Check that we have at least 4 tiles horizontally in the shard
+        num_tiles_per_perm = self.quant_config.perm_len // (
+            self.quant_config.tile_size**2
+        )
+        if output_size_per_partition % num_tiles_per_perm != 0:
+            raise ValueError("Each permutation group must reside on the same gpu")
+
+        # Quantized 4Bit weights packed into Int32.
+        qweight = PackedvLLMParameter(
+            data=torch.empty(
+                input_size_per_partition // self.quant_config.tile_size,
+                output_size_per_partition
+                * self.quant_config.tile_size
+                // self.quant_config.pack_factor,
+                device="cuda",
+                dtype=torch.int32,
+            ),
+            input_dim=0,
+            output_dim=1,
+            packed_dim=1,
+            packed_factor=self.quant_config.pack_factor,
+            marlin_tile_size=self.quant_config.tile_size,
+            weight_loader=weight_loader,
+        )
+
+        # Determine if channelwise or not
+        input_groups = (
+            1
+            if self.quant_config.group_size == -1
+            else input_size_per_partition // self.quant_config.group_size
+        )
+
+        weight_scale_args = {
+            "data": torch.empty(
+                input_groups,
+                output_size_per_partition,
+                device="cuda",
+                dtype=params_dtype,
+            ),
+            "weight_loader": weight_loader,
+        }
+        if input_groups == 1:
+            scales = ChannelQuantScaleParameter(output_dim=1, **weight_scale_args)
+        else:
+            scales = GroupQuantScaleParameter(
+                output_dim=1, input_dim=0, **weight_scale_args
+            )
+
+        # Allocate workspace (Used for internal locking mechanism)
+        max_workspace_size = (
+            output_size_per_partition // self.quant_config.min_n_threads
+        ) * self.quant_config.max_parallel
+
+        workspace = BasevLLMParameter(
+            data=torch.zeros(max_workspace_size, device="cuda", dtype=torch.int),
+            weight_loader=weight_loader,
+        )
+
+        layer.register_parameter("B", qweight)
+        layer.register_parameter("s", scales)
+        layer.register_parameter("workspace", workspace)
+
+    def process_weights_after_loading(self, layer: torch.nn.Module) -> None:
+        # required by torch.compile
+        layer.B = torch.nn.Parameter(layer.B.data, requires_grad=False)
+        layer.s = torch.nn.Parameter(layer.s.data, requires_grad=False)
+        layer.workspace = torch.nn.Parameter(layer.workspace.data, requires_grad=False)
+
+    def apply(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        qweight = layer.B
+        scales = layer.s
+        workspace = layer.workspace
+
+        x_2d = x.view(-1, x.shape[-1])
+
+        size_m = x_2d.shape[0]
+        size_k = x_2d.shape[1]
+        size_n = scales.shape[1]
+
+        output_2d = ops.marlin_gemm(
+            x_2d, qweight, scales, workspace, size_m, size_n, size_k
+        )
+
+        output = output_2d.view(x.shape[:-1] + (output_2d.shape[1],))
+
+        if bias is not None:
+            output.add_(bias)  # In-place add
+
+        return output