sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -13,7 +13,6 @@
 # ==============================================================================
 
 import math
-import os
 from typing import Callable, Optional
 
 import torch
@@ -29,6 +28,10 @@ _is_hip = is_hip()
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
 
+if _is_cuda or _is_hip:
+    from sgl_kernel import topk_softmax
+
+
 expert_distribution_recorder = ExpertDistributionRecorder()
 
 
@@ -59,11 +62,6 @@ def fused_topk(
     topk: int,
     renormalize: bool,
 ):
-    if _is_cuda or _is_hip:
-        from sgl_kernel import topk_softmax
-    else:
-        from vllm import _custom_ops as vllm_ops
-
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
     M, _ = hidden_states.shape
@@ -76,20 +74,12 @@ def fused_topk(
         M, topk, dtype=torch.int32, device=hidden_states.device
     )
 
-    if _is_cuda or _is_hip:
-        topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
-    else:
-        vllm_ops.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),
-        )
+    topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),
+    )
     del token_expert_indicies
 
     if renormalize:
@@ -108,6 +98,7 @@ def grouped_topk(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -137,9 +128,7 @@ def grouped_topk(
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -161,6 +150,7 @@ def biased_grouped_topk_impl(
     num_expert_group: int = 0,
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
 
@@ -197,9 +187,7 @@ def biased_grouped_topk_impl(
            dtype=topk_ids.dtype,
            device=topk_ids.device,
        )
-        topk_weights[:, -1] = (
-            topk_weights[:, :-1].sum(dim=-1) / 2.5
-        )  # 2.5 is the routed_scaling_factor.
+        topk_weights[:, -1] = topk_weights[:, :-1].sum(dim=-1) / routed_scaling_factor
 
     if renormalize:
         topk_weights_sum = (
@@ -226,11 +214,16 @@ def biased_grouped_topk(
     topk_group: int = 0,
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
+    routed_scaling_factor: Optional[float] = None,
 ):
+    assert (
+        routed_scaling_factor is not None
+    ), "routed_scaling_factor is required for biased_grouped_topk"
     # TODO: moe_fused_gate kernel is not supported for n_share_experts_fusion > 0 now.
     if (
         _is_cuda
-        and
+        and gating_output.shape[1] // num_expert_group
+        <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
         return moe_fused_gate(
@@ -239,6 +232,8 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             topk,
+            n_share_experts_fusion,
+            routed_scaling_factor,
         )
     else:
         biased_grouped_topk_fn = (
@@ -257,6 +252,7 @@ def biased_grouped_topk(
             num_expert_group,
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
 
@@ -271,10 +267,9 @@ def select_experts(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ):
-    n_share_experts_fusion =
-    if global_server_args_dict["n_share_experts_fusion"] is not None:
-        n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
+    n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeekSeek V2/V3/R1 serices models uses grouped_top_k
     if use_grouped_topk:
         assert topk_group is not None
@@ -288,6 +283,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -299,6 +295,7 @@ def select_experts(
             num_expert_group=num_expert_group,
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
+            routed_scaling_factor=routed_scaling_factor,
         )
     elif torch_native and custom_routing_function is None:
         topk_weights, topk_ids = fused_topk_native(
sglang/srt/layers/parameter.py
CHANGED
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -290,6 +290,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         assert activation == "silu"
         assert inplace and not no_combine
sglang/srt/layers/quantization/blockwise_int8.py
CHANGED
@@ -373,6 +373,7 @@ class BlockInt8MoEMethod:
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -388,6 +389,7 @@ class BlockInt8MoEMethod:
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         # Expert fusion with INT8 quantization
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
CHANGED
@@ -1,4 +1,4 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import logging
@@ -33,13 +33,20 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
     is_activation_quantization_format,
     should_ignore_layer,
 )
-
+
+try:
+    import vllm
+
+    VLLM_AVAILABLE = True
+except ImportError:
+    VLLM_AVAILABLE = False
 
 logger = logging.getLogger(__name__)
 
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -1,22 +1,16 @@
-# Adapted from https://github.com/vllm-project/vllm/tree/
+# Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
 
 import enum
 import logging
 from enum import Enum
-from typing import
+from typing import Callable, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
-
-from sglang.srt.layers.moe.fused_moe_triton import (
-    FusedMoE,
-    FusedMoEMethodBase,
-    FusedMoeWeightScaleSupported,
-)
-
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
     all_close_1d,
@@ -29,10 +23,9 @@ from sglang.srt.utils import set_weight_attrs
 
 _is_cuda = is_cuda()
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
+if not _is_cuda:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -58,8 +51,6 @@ __all__ = [
 
 class CompressedTensorsMoEMethod:
     def __new__(cls, *args, **kwargs):
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
         return super().__new__(cls)
@@ -76,7 +67,7 @@ class CompressedTensorsMoEMethod:
         if quant_config._is_wNa16_group_channel(weight_quant, input_quant):
             if not VLLM_AVAILABLE:
                 raise ImportError(
-                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm"
+                    "vllm is not installed, to use CompressedTensorsWNA16MoEMethod, please install vllm."
                 )
             return CompressedTensorsWNA16MoEMethod(quant_config)
         elif quant_config._is_fp8_w8a8(weight_quant, input_quant):
@@ -92,11 +83,6 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -267,19 +253,11 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
 
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
                     start += shard_size
 
         layer.w13_weight_scale = torch.nn.Parameter(
@@ -305,6 +283,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         inplace: bool = True,
         no_combine: bool = False,
         apply_router_weight_on_input: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
         from sglang.srt.layers.moe.topk import select_experts
@@ -319,6 +298,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return fused_experts(
@@ -345,11 +325,6 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
     def __init__(
         self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
     ):
-        from sglang.srt.layers.moe.fused_moe_triton import (
-            FusedMoEMethodBase,
-            FusedMoeWeightScaleSupported,
-        )
-
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -609,7 +584,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            requires_grad=False,
        )
 
-        marlin_w13_qweight =
+        marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w13_weight_packed,
            layer.w13_g_idx_sort_indices,
            layer.w13_weight_packed.shape[1] * self.packed_factor,
@@ -617,7 +592,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
            self.num_bits,
        )
        replace_tensor("w13_weight_packed", marlin_w13_qweight)
-        marlin_w2_qweight =
+        marlin_w2_qweight = vllm_ops.gptq_marlin_moe_repack(
            layer.w2_weight_packed,
            layer.w2_g_idx_sort_indices,
            layer.w2_weight_packed.shape[1] * self.packed_factor,
@@ -660,15 +635,11 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         scoring_func: str = "softmax",
         correction_bias: Optional[torch.Tensor] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
         from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if not VLLM_AVAILABLE:
-            raise ImportError(
-                "vllm is not installed, to use fused_marlin_moe, please install vllm"
-            )
         if expert_map is not None:
             raise NotImplementedError(
                 "Expert Parallelism is not supported for " "fused Marlin MoE method."
@@ -685,6 +656,7 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             custom_routing_function=custom_routing_function,
             scoring_func=scoring_func,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )
 
         return torch.ops.vllm.fused_marlin_moe(
sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py
CHANGED
@@ -2,8 +2,10 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
+    "CompressedTensorsW8A16Fp8",
 ]
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py
ADDED
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.utils import convert_to_channelwise
+
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
+__all__ = ["CompressedTensorsW8A16Fp8"]
+
+SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
+
+
+class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+
+        if not MARLIN_FP8_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
+            )
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales,
+    # we expand each scale to its shard's channels.
+    def process_weights_after_loading(self, layer) -> None:
+        if self.strategy == QuantizationStrategy.TENSOR:
+            ws_channelwise = convert_to_channelwise(
+                layer.weight_scale, layer.logical_widths
+            )
+            layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
+        else:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+
+        # Weights must be transposed for marlin
+        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
+
+        if self.is_static_input_scheme:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.input_scale = torch.nn.Parameter(
+                layer.input_scale.data, requires_grad=False
+            )
+        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size: int,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        elif self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        else:
+            raise ValueError(
+                f"Unsupported weight strategy={self.strategy}, "
+                f"supported strategies are {SUPPORTED_STRATEGIES}"
+            )
+
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE (to deal with converted checkpoints)
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py
CHANGED
@@ -16,8 +16,7 @@ from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
 )
 from sglang.srt.layers.quantization.fp8_utils import (
-
-    maybe_create_device_identity,
+    apply_fp8_linear,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.utils import is_fp8_fnuz, requantize_with_max_scale
@@ -30,7 +29,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
     def __init__(self, strategy: str, is_static_input_scheme: bool):
         self.strategy = strategy
         self.is_static_input_scheme = is_static_input_scheme
-        self.fp8_linear = Fp8LinearOp(use_per_token_if_dynamic=True)
 
     @classmethod
     def get_min_capability(cls) -> int:
@@ -99,8 +97,6 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         weight_loader: Callable,
         **kwargs,
     ):
-        maybe_create_device_identity()
-
         output_size_per_partition = sum(output_partition_sizes)
         layer.logical_widths = output_partition_sizes
 
@@ -152,11 +148,12 @@ class CompressedTensorsW8A8Fp8(CompressedTensorsScheme):
         x: torch.Tensor,
         bias: Optional[torch.Tensor] = None,
     ) -> torch.Tensor:
-
-        return self.fp8_linear.apply(
+        return apply_fp8_linear(
             input=x,
             weight=layer.weight,
             weight_scale=layer.weight_scale,
             input_scale=layer.input_scale,
             bias=bias,
+            use_per_token_if_dynamic=True,
+            compressed_tensor_quant=True,
         )