sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/utils.py
CHANGED
@@ -1,55 +1,85 @@
+from __future__ import annotations
+
 import importlib.util
 from enum import Enum
 from functools import lru_cache
+from typing import TYPE_CHECKING, Optional
 
 from packaging import version as pkg_version
 
-from sglang.srt.managers.schedule_batch import global_server_args_dict
-
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
+from sglang.srt.layers.dp_attention import (
+    get_attention_dp_size,
+    is_dp_attention_enabled,
+)
+from sglang.srt.utils import logger
 
-@lru_cache(maxsize=1)
-def should_use_flashinfer_trtllm_moe():
-    result = global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
-        not importlib.util.find_spec("flashinfer")
-        or pkg_version.parse(__import__("flashinfer").__version__)
-        >= pkg_version.parse("0.2.9rc1")
-    )
-    return result
+if TYPE_CHECKING:
+    from sglang.srt.server_args import ServerArgs
 
 
 class MoeA2ABackend(Enum):
 
-    STANDARD = ("standard", "none")
+    NONE = "none"
     DEEPEP = "deepep"
 
     @classmethod
     def _missing_(cls, value):
         if value is None:
-            return cls.STANDARD
+            return cls.NONE
         for member in cls:
-            if value in member.value:
+            if value == member.value:
                 return member
         raise ValueError(f"No {cls.__name__} member for value {value}")
 
+    def is_none(self):
+        return self == MoeA2ABackend.NONE
+
     def is_deepep(self):
         return self == MoeA2ABackend.DEEPEP
 
-    def is_standard(self):
-        return self == MoeA2ABackend.STANDARD
+
+class MoeRunnerBackend(Enum):
+
+    AUTO = "auto"
+    TRITON = "triton"
+    TRITON_KERNEL = "triton_kernel"
+    FLASHINFER = "flashinfer_trtllm"
+    FLASHINFER_CUTLASS = "flashinfer_cutlass"
+    FLASHINFER_MXFP4 = "flashinfer_mxfp4"
+
+    def is_auto(self):
+        return self == MoeRunnerBackend.AUTO
+
+    def is_triton(self):
+        return self == MoeRunnerBackend.TRITON
+
+    def is_triton_kernel(self):
+        return self == MoeRunnerBackend.TRITON_KERNEL
+
+    def is_flashinfer_trtllm(self):
+        return self == MoeRunnerBackend.FLASHINFER
+
+    def is_flashinfer_cutlass(self):
+        return self == MoeRunnerBackend.FLASHINFER_CUTLASS
+
+    def is_flashinfer_mxfp4(self):
+        return self == MoeRunnerBackend.FLASHINFER_MXFP4
 
 
 class DeepEPMode(Enum):
+
     NORMAL = "normal"
     LOW_LATENCY = "low_latency"
     AUTO = "auto"
 
-    def enable_normal(self):
+    def enable_normal(self) -> bool:
        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
 
-    def enable_low_latency(self):
+    def enable_low_latency(self) -> bool:
         return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
 
-    def resolve(self, is_extend_in_batch: bool):
+    def resolve(self, is_extend_in_batch: bool) -> DeepEPMode:
         if self != DeepEPMode.AUTO:
             return self
 
@@ -57,3 +87,114 @@ class DeepEPMode(Enum):
             return DeepEPMode.NORMAL
         else:
             return DeepEPMode.LOW_LATENCY
+
+    def is_normal(self) -> bool:
+        return self == DeepEPMode.NORMAL
+
+    def is_low_latency(self) -> bool:
+        return self == DeepEPMode.LOW_LATENCY
+
+    def is_auto(self) -> bool:
+        return self == DeepEPMode.AUTO
+
+
+MOE_A2A_BACKEND: Optional[MoeA2ABackend] = None
+MOE_RUNNER_BACKEND: Optional[MoeRunnerBackend] = None
+DEEPEP_MODE: Optional[DeepEPMode] = None
+IS_TBO_ENABLED: Optional[bool] = None
+TBO_TOKEN_DISTRIBUTION_THRESHOLD: Optional[float] = None
+DEEPEP_CONFIG: Optional[str] = None
+DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER: Optional[bool] = None
+
+
+def initialize_moe_config(server_args: ServerArgs):
+    global MOE_A2A_BACKEND
+    global MOE_RUNNER_BACKEND
+    global DEEPEP_MODE
+    global DEEPEP_CONFIG
+    global IS_TBO_ENABLED
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    global DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+
+    MOE_A2A_BACKEND = MoeA2ABackend(server_args.moe_a2a_backend)
+    MOE_RUNNER_BACKEND = MoeRunnerBackend(server_args.moe_runner_backend)
+    DEEPEP_MODE = DeepEPMode(server_args.deepep_mode)
+    DEEPEP_CONFIG = server_args.deepep_config or ""
+    IS_TBO_ENABLED = server_args.enable_two_batch_overlap
+    TBO_TOKEN_DISTRIBUTION_THRESHOLD = server_args.tbo_token_distribution_threshold
+    DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER = (
+        server_args.disable_flashinfer_cutlass_moe_fp4_allgather
+    )
+
+
+def get_moe_a2a_backend() -> MoeA2ABackend:
+    global MOE_A2A_BACKEND
+    if MOE_A2A_BACKEND is None:
+        logger.warning("MOE_A2A_BACKEND is not initialized, using default backend")
+        MOE_A2A_BACKEND = MoeA2ABackend(None)
+    return MOE_A2A_BACKEND
+
+
+def get_moe_runner_backend() -> MoeRunnerBackend:
+    global MOE_RUNNER_BACKEND
+    if MOE_RUNNER_BACKEND is None:
+        logger.warning("MOE_RUNNER_BACKEND is not initialized, using triton backend")
+        MOE_RUNNER_BACKEND = MoeRunnerBackend("triton")
+    return MOE_RUNNER_BACKEND
+
+
+def get_deepep_mode() -> DeepEPMode:
+    global DEEPEP_MODE
+    if DEEPEP_MODE is None:
+        logger.warning("DEEPEP_MODE is not initialized, using auto mode")
+        DEEPEP_MODE = DeepEPMode("auto")
+    return DEEPEP_MODE
+
+
+def get_deepep_config() -> str:
+    global DEEPEP_CONFIG
+    if DEEPEP_CONFIG is None:
+        logger.warning("DEEPEP_CONFIG is not initialized, using default config")
+        DEEPEP_CONFIG = ""
+    return DEEPEP_CONFIG
+
+
+def is_tbo_enabled() -> bool:
+    global IS_TBO_ENABLED
+    if IS_TBO_ENABLED is None:
+        logger.warning("IS_TBO_ENABLED is not initialized, using False")
+        IS_TBO_ENABLED = False
+    return IS_TBO_ENABLED
+
+
+def get_tbo_token_distribution_threshold() -> float:
+    global TBO_TOKEN_DISTRIBUTION_THRESHOLD
+    if TBO_TOKEN_DISTRIBUTION_THRESHOLD is None:
+        logger.warning(
+            "TBO_TOKEN_DISTRIBUTION_THRESHOLD is not initialized, using 0.48"
+        )
+        TBO_TOKEN_DISTRIBUTION_THRESHOLD = 0.48
+    return TBO_TOKEN_DISTRIBUTION_THRESHOLD
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    result = get_moe_runner_backend().is_flashinfer_trtllm() and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+    return result
+
+
+@lru_cache(maxsize=1)
+def should_use_flashinfer_cutlass_moe_fp4_allgather():
+    """
+    Perform FP4 quantize before all-gather for flashinfer cutlass moe to reduce communication cost for high-throughput serving.
+    """
+    return (
+        not DISABLE_FLASHINFER_CUTLASS_MOE_FP4_ALLGATHER
+        and get_moe_runner_backend().is_flashinfer_cutlass()
+        and is_dp_attention_enabled()
+        and get_moe_expert_parallel_world_size() == get_attention_dp_size()
+    )
sglang/srt/layers/multimodal.py
CHANGED
@@ -17,57 +17,173 @@ import torch
 import triton
 import triton.language as tl
 
+FMIX32_C1 = 0x85EBCA6B
+FMIX32_C2 = 0xC2B2AE35
+POS_C1 = 0x27D4EB2D
+POS_C2 = 0x165667B1
+
+
+@triton.jit
+def _rotl32(x, r: tl.constexpr):
+    return (x << r) | (x >> (32 - r))
+
+
+@triton.jit
+def _fmix32(x, C1: tl.constexpr, C2: tl.constexpr):
+    c1 = tl.full((), C1, tl.uint32)
+    c2 = tl.full((), C2, tl.uint32)
+    x ^= x >> 16
+    x = x * c1
+    x ^= x >> 13
+    x = x * c2
+    x ^= x >> 16
+    return x
+
 
 @triton.jit
-def hash_kernel(
-    input_ptr,
-    output_ptr,
-    n_elements,
-    BLOCK_SIZE: tl.constexpr,
-    PRIME: tl.constexpr,
-    XCONST: tl.constexpr,
+def hash_tiles32_kernel_blocked(
+    in_ptr,
+    out_ptr,
+    n_u32,
+    seed1,
+    seed2,
+    FM_C1: tl.constexpr,
+    FM_C2: tl.constexpr,
+    POS_A: tl.constexpr,
+    POS_B: tl.constexpr,
+    TILE: tl.constexpr,
+    BLOCK: tl.constexpr,
+    USE_CG: tl.constexpr,
 ):
     pid = tl.program_id(axis=0)
-    block_start = pid * BLOCK_SIZE
-    offsets = block_start + tl.arange(0, BLOCK_SIZE)
-    mask = offsets < n_elements
+    base = pid * TILE
+
+    s1 = tl.full((), seed1, tl.uint32)
+    s2 = tl.full((), seed2, tl.uint32)
+    posA = tl.full((), POS_A, tl.uint32)
+    posB = tl.full((), POS_B, tl.uint32)
+
+    h1 = tl.zeros((), dtype=tl.uint32)
+    h2 = tl.zeros((), dtype=tl.uint32)
+
+    for off in tl.static_range(0, TILE, BLOCK):
+        idx = base + off + tl.arange(0, BLOCK)
+        m = idx < n_u32
 
-    data = tl.load(input_ptr + offsets, mask=mask, other=0)
-    mixed = data ^ (offsets + XCONST)
-    hash_val = mixed * PRIME
-    hash_val = hash_val ^ (hash_val >> 16)
-    hash_val = hash_val * (PRIME ^ XCONST)
-    hash_val = hash_val ^ (hash_val >> 13)
+        if USE_CG:
+            v = tl.load(in_ptr + idx, mask=m, other=0, cache_modifier=".cg")
+        else:
+            v = tl.load(in_ptr + idx, mask=m, other=0)
+        v = v.to(tl.uint32)
+
+        iu = idx.to(tl.uint32)
+        p1 = (iu * posA + s1) ^ _rotl32(iu, 15)
+        p2 = (iu * posB + s2) ^ _rotl32(iu, 13)
+
+        k1 = _fmix32(v ^ p1, C1=FM_C1, C2=FM_C2)
+        k2 = _fmix32(v ^ p2, C1=FM_C1, C2=FM_C2)
+
+        zero32 = tl.zeros_like(k1)
+        k1 = tl.where(m, k1, zero32)
+        k2 = tl.where(m, k2, zero32)
+
+        h1 += tl.sum(k1, axis=0).to(tl.uint32)
+        h2 += tl.sum(k2, axis=0).to(tl.uint32)
+
+    nbytes = tl.full((), n_u32 * 4, tl.uint32)
+    h1 ^= nbytes
+    h2 ^= nbytes
+    h1 = _fmix32(h1, C1=FM_C1, C2=FM_C2)
+    h2 = (
+        _fmix32(h2, C1=FMIX32_C1, C2=FMIX32_C2)
+        if False
+        else _fmix32(h2, C1=FM_C1, C2=FM_C2)
+    )
+
+    out = (h1.to(tl.uint64) << 32) | h2.to(tl.uint64)
+    tl.store(out_ptr + pid, out)
+
+
+@triton.jit
+def add_tree_reduce_u64_kernel(in_ptr, out_ptr, n_elems, CHUNK: tl.constexpr):
+    pid = tl.program_id(axis=0)
+    start = pid * CHUNK
+    h = tl.zeros((), dtype=tl.uint64)
+    for i in tl.static_range(0, CHUNK):
+        idx = start + i
+        m = idx < n_elems
+        v = tl.load(in_ptr + idx, mask=m, other=0).to(tl.uint64)
+        h += v
+    tl.store(out_ptr + pid, h)
 
-    tl.store(output_ptr + offsets, hash_val, mask=mask)
 
+def _as_uint32_words(t: torch.Tensor) -> torch.Tensor:
+    assert t.is_cuda, "Use .cuda() first"
+    tb = t.contiguous().view(torch.uint8)
+    nbytes = tb.numel()
+    pad = (4 - (nbytes & 3)) & 3
+    if pad:
+        tb_p = torch.empty(nbytes + pad, dtype=torch.uint8, device=tb.device)
+        tb_p[:nbytes].copy_(tb)
+        tb_p[nbytes:].zero_()
+        tb = tb_p
+    return tb.view(torch.uint32)
 
-PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
-PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
 
+def _final_splitmix64(x: int) -> int:
+    mask = (1 << 64) - 1
+    x &= mask
+    x ^= x >> 30
+    x = (x * 0xBF58476D1CE4E5B9) & mask
+    x ^= x >> 27
+    x = (x * 0x94D049BB133111EB) & mask
+    x ^= x >> 31
+    return x
 
-def gpu_tensor_hash(tensor: torch.Tensor) -> int:
-    assert tensor.is_cuda
-    tensor = tensor.contiguous().view(torch.int32)
-    n = tensor.numel()
-    BLOCK_SIZE = 1024
-    grid = (triton.cdiv(n, BLOCK_SIZE),)
 
-
+@torch.inference_mode()
+def gpu_tensor_hash(
+    tensor: torch.Tensor,
+    *,
+    seed: int = 0x243F6A88,
+    tile_words: int = 8192,
+    block_words: int = 256,
+    reduce_chunk: int = 1024,
+    num_warps: int = 4,
+    num_stages: int = 4,
+    use_cg: bool = True,
+) -> int:
+    assert tensor.is_cuda, "Use .cuda() first"
+    u32 = _as_uint32_words(tensor)
+    n = u32.numel()
+    if n == 0:
+        return 0
 
-
-
-
-
-
-
-
-
-
-
+    grid1 = (triton.cdiv(n, tile_words),)
+    partials = torch.empty(grid1[0], dtype=torch.uint64, device=u32.device)
+    hash_tiles32_kernel_blocked[grid1](
+        u32,
+        partials,
+        n,
+        seed1=seed & 0xFFFFFFFF,
+        seed2=((seed * 0x9E3779B1) ^ 0xDEADBEEF) & 0xFFFFFFFF,
+        FM_C1=FMIX32_C1,
+        FM_C2=FMIX32_C2,
+        POS_A=POS_C1,
+        POS_B=POS_C2,
+        TILE=tile_words,
+        BLOCK=block_words,
+        USE_CG=use_cg,
+        num_warps=num_warps,
+        num_stages=num_stages,
+    )
 
-
-
+    cur = partials
+    while cur.numel() > 1:
+        n_elems = cur.numel()
+        grid2 = (triton.cdiv(n_elems, reduce_chunk),)
+        nxt = torch.empty(grid2[0], dtype=torch.uint64, device=cur.device)
+        add_tree_reduce_u64_kernel[grid2](cur, nxt, n_elems, CHUNK=reduce_chunk)
+        cur = nxt
 
-    return
+    return _final_splitmix64(int(cur.item()))
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -16,7 +16,6 @@ try:
     )
     from vllm.model_executor.layers.quantization.deepspeedfp import DeepSpeedFPConfig
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
-    from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -37,9 +36,9 @@ except ImportError as e:
 
     AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
         ExpertsInt8Config
-    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
-        Int8TpuConfig
-    ) = DummyConfig
+    ) = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = Int8TpuConfig = (
+        DummyConfig
+    )
 
 
 from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
@@ -48,20 +47,9 @@ from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
-from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
-
-is_mxfp_supported = mxfp_supported()
-if is_mxfp_supported:
-    from sglang.srt.layers.quantization.fp4 import MxFp4Config
-
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import (
-    GPTQConfig,
-    GPTQLinearMethod,
-    GPTQMarlinConfig,
-    GPTQMarlinLinearMethod,
-    GPTQMarlinMoEMethod,
-)
+from sglang.srt.layers.quantization.fpgemm_fp8 import FBGEMMFp8Config
+from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
@@ -70,10 +58,12 @@ from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
 from sglang.srt.layers.quantization.mxfp4 import Mxfp4Config
 from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
-from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
+from sglang.srt.utils import is_cuda, is_hip, mxfp_supported
+
+_is_mxfp_supported = mxfp_supported()
 
 if TYPE_CHECKING:
     from sglang.srt.layers.moe.topk import TopKOutput
@@ -86,11 +76,16 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "modelopt_fp4": ModelOptFp4Config,
     "w8a8_int8": W8A8Int8Config,
     "w8a8_fp8": W8A8Fp8Config,
+    "awq": AWQConfig,
+    "awq_marlin": AWQMarlinConfig,
+    "gptq": GPTQConfig,
+    "gptq_marlin": GPTQMarlinConfig,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
     "petit_nvfp4": PetitNvFp4Config,
+    "fbgemm_fp8": FBGEMMFp8Config,
 }
 
 
@@ -101,29 +96,26 @@ if is_cuda():
             "mxfp4": Mxfp4Config,
         }
     )
-elif is_mxfp_supported and is_hip():
+elif _is_mxfp_supported and is_hip():
+    from sglang.srt.layers.quantization.quark.quark import QuarkConfig
+
     BASE_QUANTIZATION_METHODS.update(
         {
-            "quark": MxFp4Config,
-            "mxfp4": MxFp4Config,
+            "quark": QuarkConfig,
+            "mxfp4": Mxfp4Config,
         }
     )
 # VLLM-dependent quantization methods
 VLLM_QUANTIZATION_METHODS = {
     "aqlm": AQLMConfig,
-    "awq": AWQConfig,
     "deepspeedfp": DeepSpeedFPConfig,
     "tpu_int8": Int8TpuConfig,
-    "fbgemm_fp8": FBGEMMFp8Config,
     "marlin": MarlinConfig,
     "gguf": GGUFConfig,
     "gptq_marlin_24": GPTQMarlin24Config,
-    "awq_marlin": AWQMarlinConfig,
     "bitsandbytes": BitsAndBytesConfig,
     "qqq": QQQConfig,
     "experts_int8": ExpertsInt8Config,
-    "gptq_marlin": GPTQMarlinConfig,
-    "gptq": GPTQConfig,
 }
 
 QUANTIZATION_METHODS = {**BASE_QUANTIZATION_METHODS, **VLLM_QUANTIZATION_METHODS}
@@ -145,23 +137,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-def gptq_get_quant_method(self, layer, prefix):
-    from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
-
-    if isinstance(layer, FusedMoE):
-        return GPTQMarlinMoEMethod(self)
-
-    if isinstance(self, GPTQConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQLinearMethod
-        )
-    elif isinstance(self, GPTQMarlinConfig):
-        return get_linear_quant_method(
-            self, layer, prefix=prefix, linear_method_cls=GPTQMarlinLinearMethod
-        )
-    return None
-
-
 original_isinstance = builtins.isinstance
 
 
@@ -239,10 +214,7 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
 
 def monkey_patch_quant_configs():
     """Apply all monkey patches in one place."""
-    setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
-    setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(GPTQMarlinMoEMethod)
     monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)
 
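The upshot of this registry reshuffle: AWQ, GPTQ, and FBGEMM FP8 move out of the vLLM-backed table into `BASE_QUANTIZATION_METHODS`, so they resolve to sglang-native config classes without vLLM installed, and the GPTQ monkey patches become unnecessary. A quick sketch, assuming sglang 0.5.1:

```python
# Sketch only: these names now resolve via sglang's own modules.
from sglang.srt.layers.quantization import get_quantization_config

for name in ("awq", "gptq_marlin", "fbgemm_fp8"):
    cls = get_quantization_config(name)
    print(name, "->", f"{cls.__module__}.{cls.__name__}")
# e.g. "fbgemm_fp8" should map to sglang.srt.layers.quantization.fpgemm_fp8
```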
sglang/srt/layers/quantization/awq.py
CHANGED
@@ -29,29 +29,26 @@ from sglang.srt.layers.quantization.marlin_utils import (
     verify_marlin_supported,
     verify_marlin_supports_shape,
 )
-from sglang.srt.layers.quantization.scalar_type import scalar_types
 from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
-from sglang.srt.layers.quantization.utils import replace_parameter
+from sglang.srt.layers.quantization.utils import get_scalar_types, replace_parameter
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.topk import TopKOutput
-
-try:
-    from vllm import _custom_ops as ops
-
-    warnings.warn(
-        f"Using kernels directly from vllm. This might lead to performance degradation or "
-        f"missing functionalities as certain kernels may not be optimized. "
-    )
-except ImportError:
-    ops = None
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
+    from sglang.srt.layers.moe.topk import StandardTopKOutput
 
 from sglang.srt.utils import is_cuda, is_hip
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
 if _is_cuda:
-    from sgl_kernel import awq_dequantize, fused_marlin_moe
+    from sgl_kernel import (
+        awq_dequantize,
+        awq_marlin_moe_repack,
+        awq_marlin_repack,
+        fused_marlin_moe,
+    )
+
+
 elif _is_hip:
     from sglang.srt.layers.quantization.awq_triton import (
         awq_dequantize_triton as awq_dequantize,
@@ -64,6 +61,9 @@ else:
 logger = logging.getLogger(__name__)
 
 
+ScalarType, scalar_types = get_scalar_types()
+
+
 def is_layer_skipped_awq(prefix: str, modules_to_not_convert: List[str]):
     return any(module_name in prefix for module_name in modules_to_not_convert)
 
@@ -516,7 +516,7 @@ class AWQMarlinLinearMethod(LinearMethodBase):
         layer.workspace = marlin_make_workspace(device)
 
         # Repack weights from AWQ format to marlin format.
-        marlin_qweight = ops.awq_marlin_repack(
+        marlin_qweight = awq_marlin_repack(
             layer.qweight,
             size_k=layer.input_size_per_partition,
             size_n=layer.output_size_per_partition,
@@ -684,7 +684,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
             requires_grad=False,
         )
 
-        marlin_w13_qweight = ops.awq_marlin_moe_repack(
+        marlin_w13_qweight = awq_marlin_moe_repack(
             layer.w13_qweight,
             layer.w13_g_idx_sort_indices,
             size_k=layer.w13_qweight.shape[1],
@@ -693,7 +693,7 @@ class AWQMoEMethod(FusedMoEMethodBase):
         )
         replace_parameter(layer, "w13_qweight", marlin_w13_qweight)
 
-        marlin_w2_qweight = ops.awq_marlin_moe_repack(
+        marlin_w2_qweight = awq_marlin_moe_repack(
             layer.w2_qweight,
             layer.w2_g_idx_sort_indices,
             size_k=layer.w2_qweight.shape[1],
@@ -740,13 +740,12 @@ class AWQMoEMethod(FusedMoEMethodBase):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        **kwargs,
+        topk_output: StandardTopKOutput,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
-
-
+        assert (
+            moe_runner_config.activation == "silu"
+        ), "Only SiLU activation is supported."
 
         # The input must currently be float16
         orig_dtype = x.dtype
sglang/srt/layers/quantization/base_config.py
CHANGED
@@ -9,6 +9,7 @@ import torch
 from torch import nn
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.moe_runner import MoeRunnerConfig
     from sglang.srt.layers.moe.topk import TopKOutput
 
 
@@ -100,12 +101,7 @@ class FusedMoEMethodBase(QuantizeMethodBase):
         layer: torch.nn.Module,
         x: torch.Tensor,
         topk_output: TopKOutput,
-        *,
-        activation: str = "silu",
-        apply_router_weight_on_input: bool = False,
-        inplace: bool = True,
-        no_combine: bool = False,
-        routed_scaling_factor: Optional[float] = None,
+        moe_runner_config: MoeRunnerConfig,
     ) -> torch.Tensor:
         raise NotImplementedError
 
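Across the quantization backends, the base `apply()` signature collapses five per-call keyword arguments into a single `moe_runner_config` object. The real class lives under `sglang/srt/layers/moe/moe_runner/` (added in this release but not shown here); the dataclass below is a hypothetical stand-in that simply mirrors the removed keywords and their defaults:

```python
# Hypothetical sketch only: field names and defaults are taken from the
# keyword arguments deleted above, not from moe_runner/base.py itself.
from dataclasses import dataclass
from typing import Optional


@dataclass
class MoeRunnerConfigSketch:
    activation: str = "silu"
    apply_router_weight_on_input: bool = False
    inplace: bool = True
    no_combine: bool = False
    routed_scaling_factor: Optional[float] = None


# Subclasses now implement:
#     def apply(self, layer, x, topk_output, moe_runner_config) -> torch.Tensor
# instead of threading each flag through **kwargs.
```

Bundling the flags keeps every backend's `apply()` signature stable as new runner options are added, which is presumably why the same change appears in `awq.py`, `fp8.py`, and the other quantization modules listed at the top of this diff.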