sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post3__py3-none-any.whl
This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +9 -7
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +1 -0
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mooncake/conn.py +44 -56
- sglang/srt/distributed/parallel_state.py +33 -0
- sglang/srt/entrypoints/engine.py +30 -26
- sglang/srt/entrypoints/openai/serving_chat.py +21 -2
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/qwen3_detector.py +150 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +13 -0
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +35 -45
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +187 -12
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +24 -73
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +26 -108
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +343 -3
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +87 -53
- sglang/srt/lora/mem_pool.py +81 -33
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +241 -0
- sglang/srt/managers/io_struct.py +41 -29
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +150 -110
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +243 -61
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +11 -3
- sglang/srt/managers/tp_worker.py +14 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +7 -16
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +152 -0
- sglang/srt/mem_cache/hiradix_cache.py +179 -4
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +41 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +5 -6
- sglang/srt/model_executor/forward_batch_info.py +14 -1
- sglang/srt/model_executor/model_runner.py +109 -22
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +191 -171
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +3 -3
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -5
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +56 -18
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +393 -230
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +9 -1
- sglang/srt/two_batch_overlap.py +1 -0
- sglang/srt/utils.py +27 -1
- sglang/test/runners.py +14 -3
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/METADATA +8 -8
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/RECORD +166 -146
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post3.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py
@@ -1,5 +1,6 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import logging
 from contextlib import suppress
@@ -18,12 +19,8 @@ from compressed_tensors.quantization import (
 )
 from pydantic import BaseModel
 
-from sglang.srt.layers.linear import (
-    LinearBase,
-    LinearMethodBase,
-    UnquantizedLinearMethod,
-)
 from sglang.srt.layers.quantization.base_config import (
+    LinearMethodBase,
     QuantizationConfig,
     QuantizeMethodBase,
 )
@@ -40,9 +37,13 @@ from sglang.srt.layers.quantization.compressed_tensors.utils import (
     is_activation_quantization_format,
     should_ignore_layer,
 )
+from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
 
 try:
-    import vllm
+    from vllm.model_executor.layers.quantization.compressed_tensors.schemes.compressed_tensors_wNa16 import (
+        WNA16_SUPPORTED_BITS,
+        CompressedTensorsWNA16,
+    )
 
     VLLM_AVAILABLE = True
 except ImportError:
@@ -97,7 +98,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         self.config = config
         self.packed_modules_mapping = packed_modules_mapping
 
-    def get_linear_method(self) -> "CompressedTensorsLinearMethod":
+    def get_linear_method(self) -> CompressedTensorsLinearMethod:
         return CompressedTensorsLinearMethod(self)
 
     def get_supported_act_dtypes(cls) -> List[torch.dtype]:
@@ -117,7 +118,8 @@ class CompressedTensorsConfig(QuantizationConfig):
         self,
         layer: torch.nn.Module,
         prefix: str,
-    ) -> Optional["QuantizeMethodBase"]:
+    ) -> Optional[QuantizeMethodBase]:
+        from sglang.srt.layers.linear import LinearBase
 
         # Check if the layer is skipped for quantization.
         # TODO (@robertgshaw2): support module names
@@ -138,7 +140,7 @@ class CompressedTensorsConfig(QuantizationConfig):
         return None
 
     @classmethod
-    def from_config(cls, config: Dict[str, Any]) -> "CompressedTensorsConfig":
+    def from_config(cls, config: Dict[str, Any]) -> CompressedTensorsConfig:
        ignore: List[str] = cast(List[str], config.get("ignore", []))
        quant_format = cast(str, config.get("format"))
        target_scheme_map = cls._quantization_scheme_map_from_config(config=config)
@@ -357,7 +359,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def _get_scheme_from_parts(
         self, weight_quant: BaseModel, input_quant: BaseModel
-    ) -> "CompressedTensorsScheme":
+    ) -> CompressedTensorsScheme:
 
         # Detect If Mixed Precision
         if self._is_wNa16_group_channel(weight_quant, input_quant):
@@ -435,7 +437,7 @@ class CompressedTensorsConfig(QuantizationConfig):
 
     def get_scheme(
         self, layer: torch.nn.Module, layer_name: Optional[str] = None
-    ) -> Optional["CompressedTensorsScheme"]:
+    ) -> Optional[CompressedTensorsScheme]:
         """
         compressed-tensors supports non uniform in the following way:
 
```
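The recurring pattern in this file is PEP 563 postponed annotation evaluation: with `from __future__ import annotations` at the top of the module, every annotation is stored as a lazily evaluated string, so quoted forward references like `-> "CompressedTensorsConfig"` can be written bare, and imports needed only for annotations can move under `if TYPE_CHECKING:` or into a function body to break import cycles (as the deferred `from sglang.srt.layers.linear import LinearBase` does above). A minimal, self-contained sketch of the idiom; the class and names below are illustrative, not from sglang, and `decimal` merely stands in for a module that would otherwise import cyclically:

```python
from __future__ import annotations  # PEP 563: annotations become lazy strings

from typing import TYPE_CHECKING, Optional

if TYPE_CHECKING:
    # Type checkers process this block; the interpreter never runs it,
    # so a cyclic or heavy import here costs nothing at runtime.
    from decimal import Decimal


class Config:
    # A bare self-reference is legal only because the annotation is
    # never evaluated at runtime under the __future__ import.
    def clone(self) -> Config:
        return Config()

    def kind_of(self, value: Decimal) -> Optional[str]:
        # Deferring the real import to call time is the second trick the
        # diff uses (compare the inlined LinearBase import above).
        from decimal import Decimal

        return "decimal" if isinstance(value, Decimal) else None


if __name__ == "__main__":
    from decimal import Decimal

    print(Config().clone().kind_of(Decimal("1.5")))  # -> decimal
```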
```diff
--- a/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
+++ b/sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
@@ -1,15 +1,17 @@
 # Adapted from https://github.com/vllm-project/vllm/tree/v0.8.2/vllm/model_executor/layers/quantization/compressed_tensors
 # SPDX-License-Identifier: Apache-2.0
+from __future__ import annotations
 
 import enum
 import logging
 from enum import Enum
-from typing import Callable, List, Optional
+from typing import TYPE_CHECKING, List, Optional
 
 import torch
 from compressed_tensors import CompressionFormat
 from compressed_tensors.quantization import QuantizationStrategy
 
+from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase
 from sglang.srt.layers.quantization.fp8_kernel import is_fp8_fnuz, scaled_fp8_quant
 from sglang.srt.layers.quantization.fp8_utils import normalize_e4m3fn_to_e4m3fnuz
 from sglang.srt.layers.quantization.utils import (
@@ -18,14 +20,21 @@ from sglang.srt.layers.quantization.utils import (
     per_tensor_dequantize,
     replace_parameter,
 )
-from sglang.srt.utils import is_cpu, is_cuda, is_npu, set_weight_attrs
+from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+    from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
+        CompressedTensorsConfig,
+    )
 
 _is_cuda = is_cuda()
 _is_npu = is_npu()
 _is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_hip = is_hip()
 
-if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available)):
+if not (_is_cuda or _is_npu or (_is_cpu and _is_cpu_amx_available) or _is_hip):
     from vllm import _custom_ops as vllm_ops
     from vllm._custom_ops import scaled_fp8_quant
 
@@ -51,7 +60,7 @@ __all__ = [
 ]
 
 
-class CompressedTensorsMoEMethod:
+class CompressedTensorsMoEMethod(FusedMoEMethodBase):
     def __new__(cls, *args, **kwargs):
         if cls is CompressedTensorsMoEMethod:
             return super().__new__(cls)
@@ -59,7 +68,7 @@ class CompressedTensorsMoEMethod:
 
     @staticmethod
     def get_moe_method(
-        quant_config: "CompressedTensorsConfig",  # type: ignore # noqa E501
+        quant_config: CompressedTensorsConfig,
     ) -> "CompressedTensorsMoEMethod":
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -82,9 +91,7 @@ class CompressedTensorsMoEMethod:
 
 class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         self.weight_quant = self.quant_config.target_scheme_map["Linear"].get("weights")
         self.input_quant = self.quant_config.target_scheme_map["Linear"].get(
@@ -270,47 +277,21 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
+        apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
-        apply_router_weight_on_input: bool = False,
         routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         from sglang.srt.layers.moe.fused_moe_triton import fused_experts
-        from sglang.srt.layers.moe.topk import select_experts
-
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
 
         return fused_experts(
             x,
             layer.w13_weight,
             layer.w2_weight,
-            topk_weights=topk_weights,
-            topk_ids=topk_ids,
+            topk_output=topk_output,
             inplace=inplace,
             activation=activation,
             use_fp8_w8a8=True,
@@ -327,9 +308,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
 
 class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
 
-    def __init__(
-        self, quant_config: "CompressedTensorsConfig"  # type: ignore # noqa E501
-    ):
+    def __init__(self, quant_config: CompressedTensorsConfig):
         self.quant_config = quant_config
         # TODO: @dsikka: refactor this to use schemes as other kernels
         # are supported + check if the layer is being ignored.
@@ -628,43 +607,15 @@ class CompressedTensorsWNA16MoEMethod(CompressedTensorsMoEMethod):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool = False,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        global_num_experts: int = -1,
-        expert_map: Optional[torch.Tensor] = None,
-        custom_routing_function: Optional[Callable] = None,
-        scoring_func: str = "softmax",
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
-        routed_scaling_factor: Optional[float] = None,
+        **kwargs,
     ) -> torch.Tensor:
-        from sglang.srt.layers.moe.topk import select_experts
 
         assert activation == "silu", "Only SiLU activation is supported."
-        if expert_map is not None:
-            raise NotImplementedError(
-                "Expert Parallelism is not supported for " "fused Marlin MoE method."
-            )
 
-        topk_weights, topk_ids = select_experts(
-            hidden_states=x,
-            router_logits=router_logits,
-            use_grouped_topk=use_grouped_topk,
-            top_k=top_k,
-            renormalize=renormalize,
-            topk_group=topk_group,
-            num_expert_group=num_expert_group,
-            num_fused_shared_experts=num_fused_shared_experts,
-            custom_routing_function=custom_routing_function,
-            scoring_func=scoring_func,
-            correction_bias=correction_bias,
-            routed_scaling_factor=routed_scaling_factor,
-        )
+        topk_weights, topk_ids, router_logits = topk_output
 
         return torch.ops.vllm.fused_marlin_moe(
             x,
```