sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of this package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py

```diff
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import logging
 from contextlib import suppress
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel
 
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -40,9 +37,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
-    import
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
+        WNA16_SUPPORTED_BITS,
+        CompressedTensorsWNA16,
+    )
 
     VLLM_AVAILABLE = True
 except ImportError:
@@ -97,7 +98,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.config = config
         self.packed_modules_mapping = packed_modules_mapping
 
-    def get_linear_method(self) ->
+    def get_linear_method(self) -> CompressedTensorsLinearMethod:
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +118,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional[
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
 
         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
@@ -138,7 +140,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) ->
+    def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
         ignore: List[str] = cast(List[str], config.get("ignore", []))
         quant_format = cast(str, config.get("format"))
         target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +359,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def _get_scheme_from_parts(
         self, weight_quant: BaseModel, input_quant: BaseModel
-    ) ->
+    ) -> CompressedTensorsScheme:
 
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +437,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def get_scheme(
         self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional[
+    ) -> Optional[CompressedTensorsScheme]:
        """
        compressed-tensors supports non uniform in the following way:
 
```
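Several of these hunks are one refactor: `from __future__ import annotations` turns every annotation into a lazily evaluated string, names needed only by type checkers move under `typing.TYPE_CHECKING`, and the runtime `LinearBase` import is deferred into `get_quant_method`. Together these break the import cycle between `sglang.srt.layers.linear` and the quantization configs, which is also why `UnquantizedLinearMethod` now lives in the new `sglang/srt/layers/quantization/unquant.py` (+422 -0 above). A hedged two-module sketch of the pattern; `linear_mod`/`quant_mod` are hypothetical stand-ins, not sglang's real layout:

```python
# Hedged sketch of the deferred-import pattern seen in the hunks above.
#
# linear_mod.py (shown as comments so this file stays self-contained):
#     from quant_mod import QuantConfig   # this edge created the cycle
#     class LinearBase: ...
from __future__ import annotations  # annotations become strings, evaluated lazily

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Seen by type checkers only; never executed at runtime, so it cannot
    # re-enter the half-initialized linear_mod during import.
    from linear_mod import LinearBase


class QuantConfig:
    def get_quant_method(self, layer: object, prefix: str) -> Optional[LinearBase]:
        # Runtime import deferred to call time: by the time anyone calls
        # this, both modules have finished importing, so the cycle is harmless.
        from linear_mod import LinearBase

        if isinstance(layer, LinearBase):
            return layer
        return None
```

`compressed_tensors_moe.py` below applies the same idiom to `TopKOutput` and `CompressedTensorsConfig`, and defers its vllm custom-ops import into the Marlin repack path.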
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py

```diff
@@ -1,15 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import enum
 import logging
 from enum import Enum
-from typing import
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
@@ -18,16 +20,14 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
-
-
-
-
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
-    from vllm import _custom_ops as vllm_ops
-    from vllm._custom_ops import scaled_fp8_quant
 
 try:
     import vllm
@@ -51,7 +51,7 @@ __all__ = [
 ]
 
 
-class CompressedTensorsMoEMethod:
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     def __new__(cls, *args, **kwargs):
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
@@ -59,7 +59,7 @@ class CompressedTensorsMoEMethod:
 
     @staticmethod
     def get_moe_method(
-        quant_config:
+        quant_config: CompressedTensorsConfig,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -82,9 +82,7 @@ class CompressedTensorsMoEMethod:
 
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -270,47 +268,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
-        apply_router_weight_on_input: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
@@ -327,9 +299,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -589,6 +559,8 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
             requires_grad=False,
         )
 
+        from vllm import _custom_ops as vllm_ops
+
         marlin_w13_qweight = vllm_ops.gptq_marlin_moe_repack(
             layer.w13_weight_packed,
             layer.w13_g_idx_sort_indices,
@@ -628,43 +600,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-
-
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
-
+        **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for " "fused Marlin MoE method."
-            )
 
-        topk_weights, topk_ids =
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
```