sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/ep_moe/token_dispatcher.py
CHANGED
@@ -1,8 +1,15 @@
+import logging
+from dataclasses import dataclass
+
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.utils import DeepEPMode
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import DeepEPMode, load_json_config

 try:
-    from deep_ep import Buffer
+    from deep_ep import Buffer, Config

     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -12,7 +19,7 @@ try:
 except ImportError:
     use_deepep = False

-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from typing import Optional, Tuple, Union

 import torch
@@ -25,6 +32,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode

+logger = logging.getLogger(__name__)
+

 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
@@ -32,7 +41,6 @@ class DeepEPDispatchMode(IntEnum):


 class DeepEPBuffer:
-
     _buffer = None
     _dispatch_mode: Optional[DeepEPDispatchMode] = None
     _hidden_size: Optional[int] = None
@@ -60,8 +68,10 @@ class DeepEPBuffer:
         if deepep_mode.enable_normal():
             hidden_bytes = hidden_size * param_bytes
             for config in (
-                Buffer.get_dispatch_config(group.size()),
-                Buffer.get_combine_config(group.size()),
+                DeepEPConfig.get_instance().normal_dispatch_config
+                or Buffer.get_dispatch_config(group.size()),
+                DeepEPConfig.get_instance().normal_combine_config
+                or Buffer.get_combine_config(group.size()),
             ):
                 num_nvl_bytes = max(
                     config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -88,7 +98,12 @@ class DeepEPBuffer:
                 num_nvl_bytes,
                 num_rdma_bytes,
                 low_latency_mode=deepep_mode.enable_low_latency(),
-                num_qps_per_rank=(
+                num_qps_per_rank=(
+                    max(
+                        num_experts // group.size(),
+                        DeepEPConfig.get_instance().num_sms // 2,
+                    )
+                ),
             )
         return cls._buffer

@@ -113,6 +128,35 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY


+class DeepEPConfig:
+    _instance = None
+
+    def __init__(self):
+        config_str = global_server_args_dict["deepep_config"]
+        if config_str:
+            config_parsed = load_json_config(config_str)
+            if torch.distributed.get_rank() == 0:
+                logger.info(f"Use DeepEP Config: {config_parsed}")
+            config_dispatch = config_parsed["normal_dispatch"]
+            config_combine = config_parsed["normal_combine"]
+
+            self.normal_dispatch_config = Config(**config_dispatch)
+            self.normal_combine_config = Config(**config_combine)
+
+            assert config_dispatch["num_sms"] == config_combine["num_sms"]
+            self.num_sms = config_dispatch["num_sms"]
+        else:
+            self.normal_dispatch_config = None
+            self.normal_combine_config = None
+            self.num_sms = Buffer.num_sms
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = DeepEPConfig()
+        return cls._instance
+
+
 class _DeepEPDispatcherImplBase:
     def __init__(
         self,
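A note on the new DeepEPConfig above: it is fed by the `deepep_config` server argument, which `load_json_config` parses into a dict with `normal_dispatch` and `normal_combine` sections; each section is splatted into deep_ep's `Config`, and the two must agree on `num_sms`, half of which lower-bounds `num_qps_per_rank` when the buffer is sized. A minimal sketch of building such a config file — only the `num_sms` key is taken from the code above, and any further keys would have to match the `Config(...)` signature of the installed deep_ep build:

import json

# Hypothetical tuning file for --deepep-config. The section names and the
# shared-num_sms rule mirror DeepEPConfig.__init__ above; any extra keys
# (none shown here) are passed straight through as Config(**section).
deepep_config = {
    "normal_dispatch": {"num_sms": 20},
    "normal_combine": {"num_sms": 20},
}
assert (
    deepep_config["normal_dispatch"]["num_sms"]
    == deepep_config["normal_combine"]["num_sms"]
)
with open("deepep_config.json", "w") as f:
    json.dump(deepep_config, f)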
@@ -295,6 +339,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
             expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            config=DeepEPConfig.get_instance().normal_dispatch_config,
+        )
+
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
         )

         return (
@@ -394,6 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             previous_event=previous_event,
             allocate_on_comm_stream=previous_event is not None,
+            config=DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event

@@ -459,6 +512,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         ):
             hook() if self.return_recv_hook else event.current_stream_wait()

+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
         reorder_topk_ids = seg_indptr = None

         return (
@@ -571,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )


+@dataclass
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
 class DeepEPDispatcher:
     def __init__(
         self,
@@ -609,6 +674,8 @@ class DeepEPDispatcher:
             **common_kwargs,
         )

+        self._stage = _Stage.INITIAL
+
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
@@ -621,6 +688,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode = None,
     ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -629,6 +697,7 @@ class DeepEPDispatcher:
         self._dispatch_intermediate_state = forward_mode, inner_state

     def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
         forward_mode, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
         return self._get_impl(forward_mode).dispatch_b(*inner_state)
@@ -645,6 +714,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
     ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -653,6 +723,7 @@ class DeepEPDispatcher:
         self._combine_intermediate_state = forward_mode, inner_state

     def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
         forward_mode, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
@@ -665,3 +736,7 @@ class DeepEPDispatcher:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage
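The _Stage bookkeeping added above makes the split-phase API self-checking: calls must cycle INITIAL → AFTER_DISPATCH_A → AFTER_DISPATCH_B → AFTER_COMBINE_A → INITIAL, and `_update_stage` asserts every transition. A standalone model of just that guard (the real dispatcher also moves tensors, of course):

from enum import Enum, auto

class _Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()

class ToyDispatcher:
    """Reproduces only the stage guard from DeepEPDispatcher above."""

    def __init__(self):
        self._stage = _Stage.INITIAL

    def _update_stage(self, old_stage, new_stage):
        assert self._stage == old_stage
        self._stage = new_stage

    def dispatch_a(self):
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)

    def dispatch_b(self):
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)

    def combine_a(self):
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)

    def combine_b(self):
        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)

d = ToyDispatcher()
d.dispatch_a()
d.dispatch_b()
d.combine_a()
d.combine_b()  # one legal full cycle
try:
    d.combine_a()  # combine_a without a preceding dispatch
except AssertionError:
    print("out-of-order call caught")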
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):

         if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
+            if apply_router_weight_on_input:
+                assert (
+                    topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+                _, topk = topk_weights.shape
+                assert (
+                    topk == 1
+                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                x = x * topk_weights.to(x.dtype)
+                topk_weights = torch.ones_like(
+                    topk_weights, dtype=torch.float32
+                )  # topk_weights must be FP32 (float32)
+
             return ck_moe_2stages(
                 x,
                 layer.w13_weight,
@@ -270,6 +283,7 @@ class FusedMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
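The apply_router_weight_on_input branch above folds the router weight into the activations before the expert GEMMs and then neutralizes `topk_weights` with FP32 ones. For a purely linear expert the two orders are exactly equivalent, which a few lines of plain torch confirm; with a nonlinear expert they are not, which is why this is a per-model flag (and only for topk=1) rather than a free optimization:

import torch

torch.manual_seed(0)
x = torch.randn(4, 8)            # (num_tokens, hidden)
w = torch.randn(8, 8)            # stand-in for one expert's weight matrix
topk_weights = torch.rand(4, 1)  # topk == 1, as the assert above requires

# Usual order: run the expert, then combine with the router weight.
out_combine_after = (x @ w) * topk_weights

# The branch above: scale the input first, then combine with weight 1.0.
out_scale_input = ((x * topk_weights) @ w) * torch.ones_like(topk_weights)

torch.testing.assert_close(out_combine_after, out_scale_input)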
sglang/srt/layers/moe/topk.py
CHANGED
@@ -18,7 +18,14 @@ from typing import Callable, Optional
 import torch
 import torch.nn.functional as F

-from sglang.srt.managers.expert_distribution import ExpertDistributionRecorder
+from sglang.srt.managers.expert_distribution import (
+    ExpertDistributionRecorder,
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.expert_location_dispatch import (
+    ExpertLocationDispatchInfo,
+    topk_ids_logical_to_physical,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.utils import get_compiler_backend, is_cuda, is_hip

@@ -32,9 +39,6 @@ if _is_cuda or _is_hip:
     from sgl_kernel import topk_softmax


-expert_distribution_recorder = ExpertDistributionRecorder()
-
-
 def fused_topk_native(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
@@ -61,6 +65,7 @@ def fused_topk(
     gating_output: torch.Tensor,
     topk: int,
     renormalize: bool,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -84,7 +89,7 @@ def fused_topk(

     if renormalize:
         topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
     return topk_weights, topk_ids


@@ -99,6 +104,8 @@ def grouped_topk(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -138,7 +145,10 @@ def grouped_topk(
         )
         topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def biased_grouped_topk_impl(
@@ -151,6 +161,8 @@ def biased_grouped_topk_impl(
     topk_group: int = 0,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

@@ -197,13 +209,26 @@ def biased_grouped_topk_impl(
     )
     topk_weights = topk_weights / topk_weights_sum

-    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_weights, topk_ids = topk_weights.to(torch.float32), topk_ids.to(torch.int32)
+    topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+    _mask_topk_ids_padded_region(topk_ids, num_token_non_padded)
+    return topk_weights, topk_ids


 def is_power_of_two(n):
     return n > 0 and math.log2(n).is_integer()


+def _mask_topk_ids_padded_region(
+    topk_ids: torch.Tensor,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+):
+    if num_token_non_padded is None:
+        return
+    indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
+    topk_ids[indices >= num_token_non_padded, :] = -1
+
+
 def biased_grouped_topk(
     hidden_states: torch.Tensor,
     gating_output: torch.Tensor,
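`_mask_topk_ids_padded_region`, defined above, exists because batches padded for CUDA graphs carry junk rows past `num_token_non_padded`; stamping their expert ids to -1 keeps padding tokens out of dispatch and out of the expert-distribution statistics. Its effect on a toy batch:

import torch

topk_ids = torch.tensor(
    [[0, 3], [1, 2], [5, 0], [4, 1]], dtype=torch.int32
)  # (num_tokens=4, topk=2)
num_token_non_padded = torch.tensor(2)  # rows 2 and 3 are padding

# Same body as the function above.
indices = torch.arange(0, topk_ids.shape[0], device=topk_ids.device)
topk_ids[indices >= num_token_non_padded, :] = -1

print(topk_ids)
# tensor([[ 0,  3],
#         [ 1,  2],
#         [-1, -1],
#         [-1, -1]], dtype=torch.int32)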
@@ -215,6 +240,8 @@ def biased_grouped_topk(
     compiled: bool = True,
     n_share_experts_fusion: int = 0,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     assert (
         routed_scaling_factor is not None
@@ -226,7 +253,7 @@ def biased_grouped_topk(
         <= 32  # moe_fused_gate kernel ensure that num_experts/num_expert_group does not exceed MAX_VPT=32 now. And when kernel can handle MAX_VPT > 32, we can remove this assertion.
         and is_power_of_two(correction_bias.shape[0])
     ):
-        return moe_fused_gate(
+        topk_weights, topk_ids = moe_fused_gate(
             gating_output,
             correction_bias,
             num_expert_group,
@@ -235,6 +262,15 @@ def biased_grouped_topk(
             n_share_experts_fusion,
             routed_scaling_factor,
         )
+        # TODO merge into kernel for this branch
+        topk_ids = topk_ids_logical_to_physical(topk_ids, expert_location_dispatch_info)
+        # TODO will fuse this into kernel, thus use slow manual operation now
+        if num_token_non_padded is None:
+            return topk_weights, topk_ids
+        torch.compile(
+            _mask_topk_ids_padded_region, dynamic=True, backend=get_compiler_backend()
+        )(topk_ids, num_token_non_padded)
+        return topk_weights, topk_ids
     else:
         biased_grouped_topk_fn = (
             torch.compile(
@@ -253,6 +289,8 @@ def biased_grouped_topk(
             topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )


@@ -268,6 +306,8 @@ def select_experts(
     correction_bias: Optional[torch.Tensor] = None,
     torch_native: bool = False,
     routed_scaling_factor: Optional[float] = None,
+    num_token_non_padded: Optional[torch.Tensor] = None,
+    expert_location_dispatch_info: Optional[ExpertLocationDispatchInfo] = None,
 ):
     n_share_experts_fusion = global_server_args_dict["n_share_experts_fusion"]
     # DeepSeek V2/V3/R1 series models use grouped_top_k
@@ -284,6 +324,8 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
         topk_weights, topk_ids = biased_grouped_topk(
@@ -296,8 +338,14 @@ def select_experts(
             topk_group=topk_group,
             n_share_experts_fusion=n_share_experts_fusion,
             routed_scaling_factor=routed_scaling_factor,
+            num_token_non_padded=num_token_non_padded,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     elif torch_native and custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk_native"
+        assert expert_location_dispatch_info is None
         topk_weights, topk_ids = fused_topk_native(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -305,13 +353,22 @@ def select_experts(
             renormalize=renormalize,
         )
     elif custom_routing_function is None:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in fused_topk"
+        # Qwen3MOE uses fused_topk
         topk_weights, topk_ids = fused_topk(
             hidden_states=hidden_states,
             gating_output=router_logits,
             topk=top_k,
             renormalize=renormalize,
+            expert_location_dispatch_info=expert_location_dispatch_info,
         )
     else:
+        assert (
+            num_token_non_padded is None
+        ), "num_token_non_padded is not yet supported in custom_routing_function"
+        assert expert_location_dispatch_info is None
         topk_weights, topk_ids = custom_routing_function(
             hidden_states=hidden_states,
             gating_output=router_logits,
@@ -319,6 +376,6 @@ def select_experts(
             renormalize=renormalize,
         )

-
+    get_global_expert_distribution_recorder().on_select_experts(topk_ids=topk_ids)

     return topk_weights, topk_ids
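Threading `expert_location_dispatch_info` through every routing path lets `topk_ids_logical_to_physical` rewrite the router's logical expert ids into the physical slots chosen by the expert-location machinery (see the new expert_location*.py managers in the file list) before tokens are dispatched. Illustratively — with a plain tensor standing in for the real ExpertLocationDispatchInfo — the remap is just an index lookup:

import torch

# Hypothetical placement: logical expert i currently lives in physical slot
# logical_to_physical[i]. The real mapping is carried by
# ExpertLocationDispatchInfo; a bare tensor stands in for it here.
logical_to_physical = torch.tensor([3, 0, 2, 1])

topk_ids_logical = torch.tensor([[0, 2], [1, 3]])
topk_ids_physical = logical_to_physical[topk_ids_logical]
print(topk_ids_physical)  # tensor([[3, 2], [0, 1]])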
sglang/srt/layers/multimodal.py
ADDED
@@ -0,0 +1,70 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""Logits processing."""
+
+import torch
+import triton
+import triton.language as tl
+
+
+@triton.jit
+def hash_kernel(
+    input_ptr,
+    output_ptr,
+    n_elements,
+    BLOCK_SIZE: tl.constexpr,
+    PRIME: tl.constexpr,
+    XCONST: tl.constexpr,
+):
+    pid = tl.program_id(axis=0)
+    block_start = pid * BLOCK_SIZE
+    offsets = block_start + tl.arange(0, BLOCK_SIZE)
+    mask = offsets < n_elements
+
+    data = tl.load(input_ptr + offsets, mask=mask, other=0)
+    mixed = data ^ (offsets + XCONST)
+    hash_val = mixed * PRIME
+    hash_val = hash_val ^ (hash_val >> 16)
+    hash_val = hash_val * (PRIME ^ XCONST)
+    hash_val = hash_val ^ (hash_val >> 13)
+
+    tl.store(output_ptr + offsets, hash_val, mask=mask)
+
+
+PRIME_1 = -(11400714785074694791 ^ 0xFFFFFFFFFFFFFFFF) - 1
+PRIME_2 = -(14029467366897019727 ^ 0xFFFFFFFFFFFFFFFF) - 1
+
+
+def gpu_tensor_hash(tensor: torch.Tensor) -> int:
+    assert tensor.is_cuda
+    tensor = tensor.contiguous().view(torch.int32)
+    n = tensor.numel()
+    BLOCK_SIZE = 1024
+    grid = (triton.cdiv(n, BLOCK_SIZE),)
+
+    intermediate_hashes = torch.empty(n, dtype=torch.int32, device=tensor.device)
+
+    hash_kernel[grid](
+        tensor,
+        intermediate_hashes,
+        n,
+        BLOCK_SIZE=BLOCK_SIZE,
+        PRIME=PRIME_1,
+        XCONST=PRIME_2,
+    )
+
+    # TODO: threads can't be synced on triton kernel
+    final_hash = intermediate_hashes.sum().item()
+
+    return final_hash
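`gpu_tensor_hash` gives a cheap content fingerprint without leaving the GPU: each int32 word is mixed with its offset and two 64-bit primes, and the per-element hashes are reduced with a plain sum (hence the TODO about cross-thread synchronization). A usage sketch, assuming a CUDA device; the new sglang/srt/mem_cache/multimodal_cache.py in the file list is the likely consumer, keying cached multimodal embeddings by this kind of digest:

import torch

from sglang.srt.layers.multimodal import gpu_tensor_hash

# Byte count must be divisible by 4, since the tensor is re-viewed as int32.
a = torch.randint(0, 256, (3, 224, 224), dtype=torch.uint8, device="cuda")
b = a.clone()
c = a.clone()
c[0, 0, 0] ^= 1  # flip one bit

assert gpu_tensor_hash(a) == gpu_tensor_hash(b)  # same content, same hash
assert gpu_tensor_hash(a) != gpu_tensor_hash(c)  # almost surely differs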
sglang/srt/layers/quantization/__init__.py
CHANGED
@@ -25,7 +25,6 @@ try:
     from vllm.model_executor.layers.quantization.gptq import GPTQLinearMethod
     from vllm.model_executor.layers.quantization.gptq_marlin import (
         GPTQMarlinLinearMethod,
-        GPTQMarlinMoEMethod,
     )
     from vllm.model_executor.layers.quantization.gptq_marlin_24 import (
         GPTQMarlin24Config,
@@ -58,12 +57,17 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
     CompressedTensorsConfig,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config
-from sglang.srt.layers.quantization.gptq import GPTQConfig, GPTQMarlinConfig
+from sglang.srt.layers.quantization.gptq import (
+    GPTQConfig,
+    GPTQMarlinConfig,
+    GPTQMarlinMoEMethod,
+)
 from sglang.srt.layers.quantization.modelopt_quant import (
     ModelOptFp4Config,
     ModelOptFp8Config,
 )
 from sglang.srt.layers.quantization.moe_wna16 import MoeWNA16Config
+from sglang.srt.layers.quantization.qoq import QoQConfig
 from sglang.srt.layers.quantization.w8a8_fp8 import W8A8Fp8Config
 from sglang.srt.layers.quantization.w8a8_int8 import W8A8Int8Config

@@ -77,6 +81,7 @@ BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
     "w8a8_fp8": W8A8Fp8Config,
     "moe_wna16": MoeWNA16Config,
     "compressed-tensors": CompressedTensorsConfig,
+    "qoq": QoQConfig,
 }

 # VLLM-dependent quantization methods
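Registering QoQConfig under the "qoq" key is all it takes for the new method (see the new sglang/srt/layers/quantization/qoq.py in the file list) to be selectable by name, since quantization configs are resolved by a dict lookup over these tables. A minimal sketch of that resolution pattern — the helper name is made up, and SGLang's real lookup also merges the VLLM-dependent table:

from typing import Dict, Type

class QuantizationConfig:  # stand-in for the real base class
    pass

class QoQConfig(QuantizationConfig):  # stand-in for the real QoQ config
    pass

BASE_QUANTIZATION_METHODS: Dict[str, Type[QuantizationConfig]] = {
    "qoq": QoQConfig,
}

def resolve_quant_method(name: str) -> Type[QuantizationConfig]:
    # Hypothetical helper: look the method up by its registered name.
    try:
        return BASE_QUANTIZATION_METHODS[name]
    except KeyError:
        raise ValueError(f"Unknown quantization method: {name}") from None

assert resolve_quant_method("qoq") is QoQConfig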
sglang/srt/layers/quantization/deep_gemm.py
CHANGED
@@ -11,8 +11,10 @@ from tqdm.contrib.concurrent import thread_map
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import get_bool_env_var, get_device_sm, get_int_env_var, is_cuda

+logger = logging.getLogger(__name__)
 _ENABLE_JIT_DEEPGEMM = False
-
+
+try:
     import deep_gemm
     from deep_gemm import get_num_sms
     from deep_gemm.jit.compiler import get_nvcc_compiler
@@ -24,14 +26,14 @@ if is_cuda():
     if sm_version == 90:
         if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="true"):
             _ENABLE_JIT_DEEPGEMM = True
+except ImportError:
+    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")


 def get_enable_jit_deepgemm():
     return _ENABLE_JIT_DEEPGEMM


-logger = logging.getLogger(__name__)
-
 _BUILTIN_M_LIST = list(range(1, 1024 * 16 + 1))
 _ENABLE_JIT_DEEPGEMM_PRECOMPILE = get_bool_env_var(
     "SGL_JIT_DEEPGEMM_PRECOMPILE", "true"
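Two small fixes ride along here: `logger` now exists before the `except ImportError:` handler that uses it, and the deep_gemm import is wrapped in try/except so hosts without the package merely log a warning and leave `_ENABLE_JIT_DEEPGEMM` False. The shape of that pattern, reduced to its essentials:

import logging

logger = logging.getLogger(__name__)
_ENABLE_JIT_DEEPGEMM = False

try:
    # Optional dependency: probe for it, then apply any extra gating
    # (GPU arch, env vars) before flipping the feature flag.
    import deep_gemm  # noqa: F401

    _ENABLE_JIT_DEEPGEMM = True
except ImportError:
    logger.warning("Failed to import deepgemm, disable _ENABLE_JIT_DEEPGEMM.")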
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -52,6 +52,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     apply_w8a8_block_fp8_linear,
     cutlass_fp8_supported,
     input_to_float8,
+    is_sm100_supported,
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
@@ -470,6 +471,7 @@ class Fp8MoEMethod:
     def __init__(self, quant_config):
         self.quant_config = quant_config
         self.block_quant = self.quant_config.weight_block_size is not None
+        self.cutlass_fp8_supported = cutlass_fp8_supported()

     def create_weights(
         self,
@@ -568,6 +570,63 @@ class Fp8MoEMethod:
             layer.register_parameter("w13_weight_scale_inv", w13_weight_scale)
             layer.register_parameter("w2_weight_scale_inv", w2_weight_scale)
             assert self.quant_config.activation_scheme == "dynamic"
+            if (
+                get_bool_env_var("CUTLASS_MOE")
+                and self.cutlass_fp8_supported
+                and is_sm100_supported()
+            ):
+                self.ab_strides1 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides1 = torch.full(
+                    (num_experts,),
+                    2 * intermediate_size,
+                    device=w13_weight.device,
+                    dtype=torch.int64,
+                )
+                self.ab_strides2 = torch.full(
+                    (num_experts,),
+                    intermediate_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.c_strides2 = torch.full(
+                    (num_experts,),
+                    hidden_size,
+                    device=w2_weight.device,
+                    dtype=torch.int64,
+                )
+                self.workspace = torch.empty(
+                    90000, device=w13_weight.device, dtype=torch.uint8
+                )
+                self.a_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.out_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.a_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.b_scales_ptr = torch.empty(
+                    num_experts, device=w13_weight.device, dtype=torch.int64
+                )
+                self.expert_offsets = torch.empty(
+                    num_experts + 1, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes1 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+                self.problem_sizes2 = torch.empty(
+                    num_experts, 3, device=w13_weight.device, dtype=torch.int32
+                )
+
         else:
             # Allocate 2 scales for w1 and w3 respectively.
             # They will be combined to a single scale after weight loading.
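The constants allocated above describe the two grouped GEMMs of a block-FP8 MoE layer with a fused gate/up projection: per expert, GEMM 1 is (m_i × hidden_size) @ (hidden_size × 2·intermediate_size), and GEMM 2 — after the gating halves the width — is (m_i × intermediate_size) @ (intermediate_size × hidden_size), hence ab_strides1 = hidden_size, c_strides1 = 2·intermediate_size, ab_strides2 = intermediate_size, c_strides2 = hidden_size. A sketch of the per-expert problem sizes with made-up dimensions; the (M, N, K) ordering is an assumption about the cutlass kernel's convention:

hidden_size = 7168
intermediate_size = 2048
tokens_per_expert = [13, 0, 42, 7]  # hypothetical m_i per expert

for i, m_i in enumerate(tokens_per_expert):
    # GEMM 1: activations (m_i, hidden) @ w13^T (hidden, 2 * intermediate)
    problem_size_1 = (m_i, 2 * intermediate_size, hidden_size)
    # GEMM 2: gated intermediate (m_i, intermediate) @ w2^T (intermediate, hidden)
    problem_size_2 = (m_i, hidden_size, intermediate_size)
    print(f"expert {i}: gemm1 {problem_size_1}, gemm2 {problem_size_2}")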
@@ -913,6 +972,37 @@ class Fp8MoEMethod:
             if ret is not None:
                 return ret

+        if (
+            get_bool_env_var("CUTLASS_MOE")
+            and self.cutlass_fp8_supported
+            and self.block_quant
+            and is_sm100_supported()
+        ):
+            from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts
+
+            return cutlass_fused_experts(
+                x,
+                layer.w13_weight.transpose(1, 2),
+                layer.w2_weight.transpose(1, 2),
+                layer.w13_weight_scale_inv.transpose(1, 2),
+                layer.w2_weight_scale_inv.transpose(1, 2),
+                topk_weights,
+                topk_ids,
+                self.ab_strides1,
+                self.c_strides1,
+                self.ab_strides2,
+                self.c_strides2,
+                self.workspace,
+                self.a_ptr,
+                self.b_ptr,
+                self.out_ptr,
+                self.a_scales_ptr,
+                self.b_scales_ptr,
+                self.expert_offsets,
+                self.problem_sizes1,
+                self.problem_sizes2,
+                use_fp8_blockscale=True,
+            )
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
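This cutlass path is opt-in and multiply gated: the CUTLASS_MOE environment variable must be truthy, `cutlass_fp8_supported()` must have passed at init, the config must be block-quantized, and the device must clear `is_sm100_supported()`. A hedged sketch of opting in (the exact truthy spellings are whatever `get_bool_env_var` accepts):

import os

# Before launching the server/engine, opt in to the cutlass MoE path.
os.environ["CUTLASS_MOE"] = "1"
# Everything else (cutlass FP8 support, block quantization, SM100 hardware
# with CUDA >= 12.8) is detected at runtime by the checks shown above.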
sglang/srt/layers/quantization/fp8_utils.py
CHANGED
@@ -80,6 +80,12 @@ def cutlass_fp8_supported():
     return False


+def is_sm100_supported(device=None) -> bool:
+    return (torch.cuda.get_device_capability(device)[0] == 10) and (
+        torch.version.cuda >= "12.8"
+    )
+
+
 def normalize_e4m3fn_to_e4m3fnuz(
     weight: torch.Tensor,
     weight_scale: torch.Tensor,
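For reference, compute capability major 10 corresponds to the Blackwell data-center parts, and `torch.version.cuda` is a string, so the `>= "12.8"` check is lexicographic (adequate for the 12.8/12.9/13.x sequence). A quick probe on the local machine:

import torch

if torch.cuda.is_available():
    major, minor = torch.cuda.get_device_capability()
    print(f"compute capability: {major}.{minor}, CUDA: {torch.version.cuda}")
    print("sm100 path eligible:", major == 10 and torch.version.cuda >= "12.8")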