sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +2 -1
- sglang/eval/loogle_eval.py +7 -0
- sglang/srt/_custom_ops.py +29 -1
- sglang/srt/configs/deepseekvl2.py +11 -2
- sglang/srt/configs/internvl.py +3 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/model_config.py +10 -8
- sglang/srt/configs/update_config.py +3 -1
- sglang/srt/conversation.py +2 -1
- sglang/srt/custom_op.py +5 -2
- sglang/srt/disaggregation/common/conn.py +34 -6
- sglang/srt/disaggregation/decode.py +9 -1
- sglang/srt/disaggregation/mini_lb.py +3 -2
- sglang/srt/disaggregation/mooncake/conn.py +93 -76
- sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
- sglang/srt/disaggregation/nixl/conn.py +17 -13
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
- sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
- sglang/srt/distributed/parallel_state.py +103 -15
- sglang/srt/entrypoints/engine.py +31 -33
- sglang/srt/entrypoints/http_server.py +20 -32
- sglang/srt/entrypoints/openai/protocol.py +3 -3
- sglang/srt/entrypoints/openai/serving_chat.py +48 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -1
- sglang/srt/function_call/base_format_detector.py +74 -12
- sglang/srt/function_call/deepseekv3_detector.py +26 -11
- sglang/srt/function_call/ebnf_composer.py +95 -63
- sglang/srt/function_call/function_call_parser.py +4 -2
- sglang/srt/function_call/kimik2_detector.py +41 -16
- sglang/srt/function_call/llama32_detector.py +6 -3
- sglang/srt/function_call/mistral_detector.py +11 -3
- sglang/srt/function_call/pythonic_detector.py +16 -14
- sglang/srt/function_call/qwen25_detector.py +12 -3
- sglang/srt/function_call/qwen3_coder_detector.py +151 -0
- sglang/srt/hf_transformers_utils.py +0 -1
- sglang/srt/layers/activation.py +24 -3
- sglang/srt/layers/attention/base_attn_backend.py +3 -1
- sglang/srt/layers/attention/flashattention_backend.py +3 -3
- sglang/srt/layers/attention/flashinfer_backend.py +40 -1
- sglang/srt/layers/communicator.py +12 -12
- sglang/srt/layers/dp_attention.py +72 -24
- sglang/srt/layers/linear.py +13 -102
- sglang/srt/layers/logits_processor.py +34 -24
- sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
- sglang/srt/layers/moe/ep_moe/layer.py +23 -402
- sglang/srt/layers/moe/fused_moe_native.py +7 -47
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
- sglang/srt/layers/moe/topk.py +190 -23
- sglang/srt/layers/quantization/__init__.py +20 -134
- sglang/srt/layers/quantization/awq.py +578 -11
- sglang/srt/layers/quantization/awq_triton.py +339 -0
- sglang/srt/layers/quantization/base_config.py +85 -10
- sglang/srt/layers/quantization/blockwise_int8.py +17 -55
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
- sglang/srt/layers/quantization/fp8.py +273 -62
- sglang/srt/layers/quantization/fp8_kernel.py +210 -46
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gptq.py +501 -143
- sglang/srt/layers/quantization/marlin_utils.py +790 -0
- sglang/srt/layers/quantization/modelopt_quant.py +34 -112
- sglang/srt/layers/quantization/moe_wna16.py +45 -49
- sglang/srt/layers/quantization/petit.py +252 -0
- sglang/srt/layers/quantization/petit_utils.py +104 -0
- sglang/srt/layers/quantization/qoq.py +7 -6
- sglang/srt/layers/quantization/scalar_type.py +352 -0
- sglang/srt/layers/quantization/unquant.py +422 -0
- sglang/srt/layers/quantization/utils.py +340 -9
- sglang/srt/layers/quantization/w4afp8.py +8 -4
- sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
- sglang/srt/layers/quantization/w8a8_int8.py +51 -115
- sglang/srt/layers/radix_attention.py +5 -3
- sglang/srt/layers/vocab_parallel_embedding.py +1 -41
- sglang/srt/lora/lora.py +0 -4
- sglang/srt/lora/lora_manager.py +162 -164
- sglang/srt/lora/lora_registry.py +124 -0
- sglang/srt/lora/mem_pool.py +83 -35
- sglang/srt/lora/utils.py +12 -5
- sglang/srt/managers/cache_controller.py +288 -0
- sglang/srt/managers/io_struct.py +60 -30
- sglang/srt/managers/mm_utils.py +7 -8
- sglang/srt/managers/schedule_batch.py +163 -113
- sglang/srt/managers/schedule_policy.py +68 -27
- sglang/srt/managers/scheduler.py +256 -86
- sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
- sglang/srt/managers/tokenizer_manager.py +38 -27
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/allocator.py +74 -23
- sglang/srt/mem_cache/base_prefix_cache.py +14 -2
- sglang/srt/mem_cache/chunk_cache.py +5 -2
- sglang/srt/mem_cache/hicache_storage.py +168 -0
- sglang/srt/mem_cache/hiradix_cache.py +194 -5
- sglang/srt/mem_cache/memory_pool.py +16 -1
- sglang/srt/mem_cache/memory_pool_host.py +44 -2
- sglang/srt/mem_cache/radix_cache.py +26 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +66 -31
- sglang/srt/model_executor/forward_batch_info.py +210 -25
- sglang/srt/model_executor/model_runner.py +147 -42
- sglang/srt/model_loader/loader.py +7 -1
- sglang/srt/model_loader/utils.py +4 -4
- sglang/srt/models/clip.py +1 -1
- sglang/srt/models/deepseek.py +9 -6
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_v2.py +192 -173
- sglang/srt/models/deepseek_vl2.py +5 -5
- sglang/srt/models/gemma.py +48 -0
- sglang/srt/models/gemma2.py +52 -0
- sglang/srt/models/gemma3_causal.py +63 -0
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/gemma3n_mm.py +2 -4
- sglang/srt/models/granitemoe.py +385 -0
- sglang/srt/models/grok.py +9 -3
- sglang/srt/models/hunyuan.py +63 -16
- sglang/srt/models/internvl.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -1
- sglang/srt/models/llama.py +41 -0
- sglang/srt/models/llama4.py +11 -11
- sglang/srt/models/llava.py +2 -2
- sglang/srt/models/llavavid.py +1 -1
- sglang/srt/models/minicpm.py +0 -2
- sglang/srt/models/minicpmo.py +3 -7
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mixtral.py +9 -2
- sglang/srt/models/mllama.py +3 -5
- sglang/srt/models/mllama4.py +13 -6
- sglang/srt/models/olmoe.py +8 -5
- sglang/srt/models/persimmon.py +330 -0
- sglang/srt/models/phi.py +321 -0
- sglang/srt/models/phi4mm.py +44 -4
- sglang/srt/models/phi4mm_audio.py +1260 -0
- sglang/srt/models/phi4mm_utils.py +1917 -0
- sglang/srt/models/phimoe.py +9 -3
- sglang/srt/models/qwen.py +37 -0
- sglang/srt/models/qwen2.py +41 -0
- sglang/srt/models/qwen2_5_vl.py +4 -4
- sglang/srt/models/qwen2_audio.py +1 -1
- sglang/srt/models/qwen2_moe.py +53 -9
- sglang/srt/models/qwen2_vl.py +4 -4
- sglang/srt/models/qwen3.py +65 -1
- sglang/srt/models/qwen3_moe.py +57 -24
- sglang/srt/models/vila.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +91 -97
- sglang/srt/multimodal/processors/clip.py +21 -19
- sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
- sglang/srt/multimodal/processors/gemma3.py +13 -17
- sglang/srt/multimodal/processors/gemma3n.py +19 -23
- sglang/srt/multimodal/processors/internvl.py +9 -10
- sglang/srt/multimodal/processors/janus_pro.py +12 -27
- sglang/srt/multimodal/processors/kimi_vl.py +12 -14
- sglang/srt/multimodal/processors/llava.py +4 -2
- sglang/srt/multimodal/processors/minicpm.py +35 -44
- sglang/srt/multimodal/processors/mlama.py +21 -18
- sglang/srt/multimodal/processors/mllama4.py +4 -5
- sglang/srt/multimodal/processors/phi4mm.py +63 -39
- sglang/srt/multimodal/processors/pixtral.py +14 -35
- sglang/srt/multimodal/processors/qwen_audio.py +65 -0
- sglang/srt/multimodal/processors/qwen_vl.py +16 -21
- sglang/srt/multimodal/processors/vila.py +14 -14
- sglang/srt/reasoning_parser.py +46 -4
- sglang/srt/sampling/sampling_batch_info.py +6 -5
- sglang/srt/sampling/sampling_params.py +8 -1
- sglang/srt/server_args.py +454 -270
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
- sglang/srt/speculative/eagle_utils.py +51 -23
- sglang/srt/speculative/eagle_worker.py +59 -44
- sglang/srt/two_batch_overlap.py +10 -5
- sglang/srt/utils.py +44 -69
- sglang/test/runners.py +14 -3
- sglang/test/test_activation.py +50 -1
- sglang/test/test_block_fp8.py +8 -3
- sglang/test/test_block_fp8_ep.py +1 -1
- sglang/test/test_custom_ops.py +12 -7
- sglang/test/test_cutlass_w4a8_moe.py +1 -3
- sglang/test/test_fp4_moe.py +1 -3
- sglang/test/test_marlin_moe.py +286 -0
- sglang/test/test_marlin_utils.py +171 -0
- sglang/test/test_utils.py +35 -0
- sglang/version.py +1 -1
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
- sglang/srt/layers/quantization/quant_utils.py +0 -166
- sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/topk.py
CHANGED
@@ -12,22 +12,21 @@
 # limitations under the License.
 # ==============================================================================
 
+from __future__ import annotations
+
 import math
-from typing import Callable, Optional
+from typing import Callable, NamedTuple, Optional
 
 import torch
 import torch.nn.functional as F
 
+from sglang.srt.custom_op import CustomOp
 from sglang.srt.eplb import expert_location_dispatch
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location_dispatch import (
     ExpertLocationDispatchInfo,
     topk_ids_logical_to_physical,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import (
     cpu_has_amx_support,
     get_bool_env_var,
@@ -40,10 +39,10 @@ from sglang.srt.utils import (
 
 _is_cuda = is_cuda()
 _is_hip = is_hip()
-_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-_is_cpu_amx_available = cpu_has_amx_support()
 _is_cpu = is_cpu()
+_is_cpu_amx_available = cpu_has_amx_support()
 _is_npu = is_npu()
+_use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
 
 if _is_cuda:
     from sgl_kernel import moe_fused_gate
@@ -55,6 +54,167 @@ if _use_aiter:
         from aiter import biased_grouped_topk as aiter_biased_grouped_topk
     except ImportError:
         raise ImportError("aiter is required when SGLANG_USE_AITER is set to True")
+if _is_npu:
+    import torch_npu
+
+
+class TopKOutput(NamedTuple):
+    topk_weights: torch.Tensor
+    topk_ids: torch.Tensor
+    router_logits: torch.Tensor
+
+
+class TopK(CustomOp):
+
+    # TODO(ch-wan): support triton_kernels
+
+    def __init__(
+        self,
+        top_k: int,
+        *,
+        use_grouped_topk: bool = False,
+        topk_group: Optional[int] = None,
+        num_expert_group: Optional[int] = None,
+        renormalize: bool = True,
+        num_fused_shared_experts: int = 0,
+        custom_routing_function: Optional[Callable] = None,
+        scoring_func: str = "softmax",
+        correction_bias: Optional[torch.Tensor] = None,
+        routed_scaling_factor: Optional[float] = None,
+    ):
+        # NOTE: scoring_func is not used for now, but we keep it for future use
+        # see https://github.com/sgl-project/sglang/pull/4505 for more details
+        super().__init__()
+        if use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.top_k = top_k
+        self.use_grouped_topk = use_grouped_topk
+        self.renormalize = renormalize
+        self.topk_group = topk_group
+        self.num_expert_group = num_expert_group
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.custom_routing_function = custom_routing_function
+        self.correction_bias = correction_bias
+        self.routed_scaling_factor = routed_scaling_factor
+
+    def forward_native(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = True
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cuda(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        torch_native = False
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            torch_native=torch_native,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_cpu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        return select_experts(
+            hidden_states=hidden_states,
+            router_logits=router_logits,
+            top_k=self.top_k,
+            use_grouped_topk=self.use_grouped_topk,
+            renormalize=self.renormalize,
+            topk_group=self.topk_group,
+            num_expert_group=self.num_expert_group,
+            num_fused_shared_experts=self.num_fused_shared_experts,
+            custom_routing_function=self.custom_routing_function,
+            correction_bias=self.correction_bias,
+            routed_scaling_factor=self.routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
+        )
+
+    def forward_npu(
+        self,
+        hidden_states: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        num_token_non_padded: Optional[torch.Tensor] = None,
+        expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
+    ) -> TopKOutput:
+        global_num_experts = router_logits.shape[-1]
+
+        # NOTE: now npu_moe_gating_top_k can only support `group_count=256` pattern
+        if global_num_experts == 256:
+            return torch_npu.npu_moe_gating_top_k(
+                router_logits,
+                k=self.top_k,
+                bias=self.correction_bias,
+                k_group=self.topk_group,
+                group_count=self.num_expert_group,
+                group_select_mode=1,
+                renorm=0,
+                norm_type=1,
+                routed_scaling_factor=1,
+                eps=float(1e-20),
+            )
+        else:
+            torch_native = True
+            return select_experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+                top_k=self.top_k,
+                use_grouped_topk=self.use_grouped_topk,
+                renormalize=self.renormalize,
+                topk_group=self.topk_group,
+                num_expert_group=self.num_expert_group,
+                num_fused_shared_experts=self.num_fused_shared_experts,
+                custom_routing_function=self.custom_routing_function,
+                correction_bias=self.correction_bias,
+                torch_native=torch_native,
+                routed_scaling_factor=self.routed_scaling_factor,
+                num_token_non_padded=num_token_non_padded,
+                expert_location_dispatch_info=expert_location_dispatch_info,
+            )
 
 
 def fused_topk_torch_native(
@@ -97,6 +257,19 @@ def fused_topk_cpu(
     return topk_weights, topk_ids
 
 
+def apply_topk_weights_cpu(need_apply, topk_weights, inputs):
+    if not need_apply:
+        return inputs, topk_weights
+
+    # TODO: fuse below processing in fused_experts_cpu kernel
+    inputs = inputs * topk_weights.to(inputs.dtype)
+    topk_weights = torch.ones_like(
+        topk_weights, dtype=torch.float32
+    )  # clear topk_weights as already applied
+
+    return inputs, topk_weights
+
+
 def fused_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -213,6 +386,7 @@ def grouped_topk_cpu(
     )
 
 
+@torch.compile(dynamic=True, backend=get_compiler_backend(), disable=_is_npu)
 def biased_grouped_topk_impl(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -308,7 +482,6 @@ def biased_grouped_topk_gpu(
     renormalize: bool,
     num_expert_group: int = 0,
     topk_group: int = 0,
-    compiled: bool = not _is_npu,
     num_fused_shared_experts: int = 0,
     routed_scaling_factor: Optional[float] = None,
     num_token_non_padded: Optional[torch.Tensor] = None,
@@ -325,7 +498,7 @@ def biased_grouped_topk_gpu(
         and is_power_of_two(correction_bias.shape[0])
     ):
         topk_weights, topk_ids = moe_fused_gate(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
             correction_bias,
             num_expert_group,
             topk_group,
@@ -350,7 +523,7 @@ def biased_grouped_topk_gpu(
         topk_weights = torch.empty((token, topk), dtype=torch.float32, device=device)
         topk_ids = torch.empty((token, topk), dtype=torch.int32, device=device)
         aiter_biased_grouped_topk(
-            gating_output,
+            gating_output.to(dtype=torch.float32),
             correction_bias,
             topk_weights,
             topk_ids,
@@ -361,14 +534,7 @@ def biased_grouped_topk_gpu(
         )
         return topk_weights, topk_ids
     else:
-        biased_grouped_topk_fn = (
-            torch.compile(
-                biased_grouped_topk_impl, dynamic=True, backend=get_compiler_backend()
-            )
-            if compiled
-            else biased_grouped_topk_impl
-        )
-        return biased_grouped_topk_fn(
+        return biased_grouped_topk_impl(
             hidden_states,
             gating_output,
             correction_bias,
@@ -427,8 +593,9 @@ def select_experts(
     hidden_states: torch.Tensor,
     router_logits: torch.Tensor,
     top_k: int,
-    use_grouped_topk: bool,
-    renormalize: bool,
+    *,
+    use_grouped_topk: bool = False,
+    renormalize: bool = False,
     topk_group: Optional[int] = None,
     num_expert_group: Optional[int] = None,
     num_fused_shared_experts: int = 0,
@@ -438,7 +605,7 @@ def select_experts(
     routed_scaling_factor: Optional[float] = None,
    num_token_non_padded: Optional[torch.Tensor] = None,
     expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
-):
+) -> TopKOutput:
     router_logits, correction_bias = (
         expert_location_dispatch.transform_select_experts_inputs(
            router_logits=router_logits,
@@ -513,4 +680,4 @@ def select_experts(
 
     get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)
 
-    return topk_weights, topk_ids
+    return TopKOutput(topk_weights, topk_ids, router_logits)
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -1,18 +1,14 @@
 # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/__init__.py
+from __future__ import annotations
+
 import builtins
 import inspect
-import re
-from copy import deepcopy
-from typing import Callable, Dict, Optional, Type, Union
+from typing import TYPE_CHECKING, Callable, Dict, Optional, Type, Union
 
 import torch
 
 try:
     from vllm.model_executor.layers.quantization.aqlm import AQLMConfig
-    from vllm.model_executor.layers.quantization.awq_marlin import (
-        AWQMarlinConfig,
-        AWQMoEMethod,
-    )
     from vllm.model_executor.layers.quantization.bitsandbytes import BitsAndBytesConfig
     from vllm.model_executor.layers.quantization.compressed_tensors.compressed_tensors_moe import (
         CompressedTensorsW8A8Fp8MoEMethod,
@@ -22,10 +18,6 @@ try:
     from vllm.model_executor.layers.quantization.experts_int8 import ExpertsInt8Config
     from vllm.model_executor.layers.quantization.fbgemm_fp8 import FBGEMMFp8Config
     from vllm.model_executor.layers.quantization.gguf import GGUFConfig
-    from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
-    from vllm.model_executor.layers.quantization.gptq_marlin import (
-        GPTQMarlinLinearMethod,
-    )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
     )
@@ -42,15 +34,14 @@ except ImportError:
         def override_quantization_method(self, *args, **kwargs):
             return None
 
-    AQLMConfig =
-
-    ) =
-
-    ) =
+    AQLMConfig = BitsAndBytesConfig = CompressedTensorsConfig = DeepSpeedFPConfig = (
+        ExpertsInt8Config
+    ) = FBGEMMFp8Config = GGUFConfig = GPTQMarlin24Config = MarlinConfig = QQQConfig = (
+        Int8TpuConfig
+    ) = DummyConfig
 
 
-from sglang.srt.layers.linear import LinearBase, UnquantizedLinearMethod
-from sglang.srt.layers.quantization.awq import AWQConfig
+from sglang.srt.layers.quantization.awq import AWQConfig, AWQMarlinConfig
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.blockwise_int8 import BlockInt8Config
 from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
@@ -59,7 +50,9 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config
 from sglang.srt.layers.quantization.gptq import (
     GPTQConfig,
+    GPTQLinearMethod,
     GPTQMarlinConfig,
+    GPTQMarlinLinearMethod,
     GPTQMarlinMoEMethod,
 )
 from sglang.srt.layers.quantization.modelopt_quant import (
@@ -67,11 +60,16 @@ from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.petit import PetitNvFp4Config
 from sglang.srt.layers.quantization.qoq import QoQConfig
+from sglang.srt.layers.quantization.utils import get_linear_quant_method
 from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.topk import TopKOutput
+
 # Base quantization methods that don't depend on vllm
 BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "fp8": Fp8Config,
@@ -84,6 +82,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "compressed-tensors": CompressedTensorsConfig,
     "qoq": QoQConfig,
     "w4afp8": W4AFp8Config,
+    "petit_nvfp4": PetitNvFp4Config,
 }
 
 # VLLM-dependent quantization methods
@@ -122,99 +121,6 @@ def get_quantization_config(quantization: str) -> Type[QuantizationConfig]:
     return QUANTIZATION_METHODS[quantization]
 
 
-# Match dynamic rules with module name (prefix) and override quantize
-# config if module (prefix) matches a rule
-def override_config(config: QuantizationConfig, prefix: str):
-    weight_bits = get_dynamic_override(config, prefix, "bits", config.weight_bits)
-    if isinstance(weight_bits, int):
-        config.weight_bits = weight_bits
-    group_size = get_dynamic_override(config, prefix, "group_size", config.group_size)
-    if isinstance(group_size, int):
-        config.group_size = group_size
-    desc_act = get_dynamic_override(config, prefix, "desc_act", config.desc_act)
-    if isinstance(desc_act, bool):
-        config.desc_act = desc_act
-
-    config.pack_factor = 32 // config.weight_bits  # packed into int32
-    if config.get_name() == "gptq_marlin":
-        is_sym = get_dynamic_override(config, prefix, "sym", config.is_sym)
-        if isinstance(is_sym, bool):
-            config.is_sym = is_sym
-
-        if (config.weight_bits, config.is_sym) not in config.TYPE_MAP:
-            raise ValueError(
-                "Unsupported quantization config: "
-                f"bits={config.weight_bits}, sym={config.is_sym}"
-            )
-
-        config.quant_type = config.TYPE_MAP[(config.weight_bits, config.is_sym)]
-    elif config.get_name() == "gptq":
-        if config.weight_bits not in [2, 3, 4, 8]:
-            raise ValueError(
-                "Currently, only 2/3/4/8-bit weight quantization is "
-                f"supported for GPTQ, but got {config.weight_bits} bits."
-            )
-
-
-def get_dynamic_override(
-    config: QuantizationConfig,
-    layer_name: str,
-    key: Optional[str] = None,
-    default_value: Union[int, bool, None] = None,
-) -> Union[Dict, int, bool, None]:
-    for pattern, pattern_dict in config.dynamic.items():
-        # Negative match: matched modules are excluded from quantized init
-        if pattern.startswith("-:"):
-            if re.match(pattern.removeprefix("-:"), layer_name):
-                return False
-        # Positive match: matched modules have quant properties overrides
-        # base quant config
-        elif re.match(pattern.removeprefix("+:"), layer_name):
-            if key is None:
-                return pattern_dict
-            else:
-                return pattern_dict.get(key, default_value)
-    return default_value
-
-
-def get_linear_quant_method(
-    config: QuantizationConfig,
-    layer: torch.nn.Module,
-    prefix: str,
-    linear_method_cls: type,
-):
-    # Move import here to avoid circular import. This is only used in monkey patching
-    # of vllm's QuantizationConfig.
-    from sglang.srt.layers.vocab_parallel_embedding import (
-        ParallelLMHead,
-        UnquantizedEmbeddingMethod,
-    )
-
-    cloned_config = deepcopy(config)
-    parallel_lm_head_quantized = (
-        isinstance(layer, ParallelLMHead) and cloned_config.lm_head_quantized
-    )
-
-    if isinstance(layer, LinearBase) or parallel_lm_head_quantized:
-        # False = skip module, None = no override, else = Positive match
-        if (
-            get_dynamic_override(  # noqa: E712
-                cloned_config, layer_name=prefix  # noqa: E712
-            )
-            == False
-        ):  # noqa: E712
-            if parallel_lm_head_quantized:
-                return UnquantizedEmbeddingMethod()
-            return UnquantizedLinearMethod()
-
-        if prefix:
-            # Dynamic per module/layer rules may override base config
-            override_config(cloned_config, prefix=prefix)
-
-        return linear_method_cls(cloned_config)
-    return None
-
-
 def gptq_get_quant_method(self, layer, prefix):
     from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 
@@ -285,15 +191,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
         self,
         layer: torch.nn.Module,
         x: torch.Tensor,
-        router_logits: torch.Tensor,
-        top_k: int,
-        renormalize: bool,
-        use_grouped_topk: bool,
-        topk_group: Optional[int] = None,
-        num_expert_group: Optional[int] = None,
-        num_fused_shared_experts: int = 0,
-        custom_routing_function: Optional[Callable] = None,
-        correction_bias: Optional[torch.Tensor] = None,
+        topk_output: TopKOutput,
+        *,
         activation: str = "silu",
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
@@ -307,20 +206,8 @@ def monkey_patch_moe_apply(class_obj: "FusedMoEMethodBase"):
             "self": self,
             "layer": layer,
             "x": x,
-            "router_logits": router_logits,
-            "top_k": top_k,
-            "renormalize": renormalize,
-            "use_grouped_topk": use_grouped_topk,
-            "topk_group": topk_group,
-            "num_expert_group": num_expert_group,
-            "custom_routing_function": custom_routing_function,
+            "topk_output": topk_output,
         }
-        if correction_bias is not None:
-            if not has_correction_bias:
-                raise ValueError(
-                    "Please increase the version of your vllm. Try `pip install vllm==0.9.0.1`"
-                )
-            kwargs["e_score_correction_bias"] = correction_bias
         return original_apply(**kwargs)
 
     setattr(class_obj, "apply", new_apply)
@@ -331,7 +218,6 @@ def monkey_patch_quant_configs():
     setattr(GPTQMarlinConfig, "get_quant_method", gptq_get_quant_method)
     setattr(GPTQConfig, "get_quant_method", gptq_get_quant_method)
 
-    monkey_patch_moe_apply(AWQMoEMethod)
     monkey_patch_moe_apply(GPTQMarlinMoEMethod)
    monkey_patch_moe_apply(CompressedTensorsW8A8Fp8MoEMethod)
     monkey_patch_moe_apply(CompressedTensorsWNA16MoEMethod)