sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -1,20 +1,25 @@
 # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/layer.py
 
+import importlib.util
 import logging
 from enum import Enum
+from functools import lru_cache
 from typing import List, Optional, Tuple
 
 import torch
+from packaging import version as pkg_version
 
 from sglang.srt.distributed import (
     get_moe_expert_parallel_rank,
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-
-    get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
@@ -33,6 +38,15 @@ _is_cpu = is_cpu()
 logger = logging.getLogger(__name__)
 
 
+@lru_cache(maxsize=1)
+def should_use_flashinfer_trtllm_moe():
+    return global_server_args_dict["enable_flashinfer_trtllm_moe"] and (
+        not importlib.util.find_spec("flashinfer")
+        or pkg_version.parse(__import__("flashinfer").__version__)
+        >= pkg_version.parse("0.2.9rc1")
+    )
+
+
 class FusedMoeWeightScaleSupported(Enum):
     TENSOR = "tensor"
     CHANNEL = "channel"
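The helper above is memoized with lru_cache(maxsize=1), so the server-args flag and the installed flashinfer version are inspected only once per process. The "0.2.9rc1" threshold works because packaging orders pre-releases before the corresponding final release; a small standalone check of that assumption (plain Python, not part of the diff):

from packaging import version as pkg_version

# 0.2.9rc1 is a pre-release of 0.2.9, so any released 0.2.9+ satisfies the gate.
assert pkg_version.parse("0.2.9") >= pkg_version.parse("0.2.9rc1")
assert pkg_version.parse("0.3.0") >= pkg_version.parse("0.2.9rc1")
# Older releases do not.
assert not (pkg_version.parse("0.2.8") >= pkg_version.parse("0.2.9rc1"))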
@@ -82,7 +96,6 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
-        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()
 
@@ -100,7 +113,6 @@ class FusedMoE(torch.nn.Module):
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
-            enable_ep_moe = False
 
         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -109,7 +121,7 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if
+        if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
@@ -119,7 +131,8 @@ class FusedMoE(torch.nn.Module):
                 * self.num_local_experts : (self.moe_ep_rank + 1)
                 * self.num_local_experts
             ] = torch.arange(0, self.num_local_experts, dtype=torch.int32, device="cpu")
-
+            if not self.enable_flashinfer_cutlass_moe:
+                self.expert_map_gpu = self.expert_map_cpu.to(device="cuda")
 
         self.routed_scaling_factor = routed_scaling_factor
         assert intermediate_size % self.moe_tp_size == 0
@@ -454,7 +467,7 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Flashinfer assumes w31 format for w13_weight. Same for the scales.
-        if
+        if should_use_flashinfer_trtllm_moe():
             shard_id = {"w1": "w3", "w3": "w1", "w2": "w2"}[shard_id]
 
         WEIGHT_SCALE_SUPPORTED = [e.value for e in FusedMoeWeightScaleSupported]
@@ -617,24 +630,27 @@ class FusedMoE(torch.nn.Module):
         )
 
         # Matrix multiply.
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+        with use_symmetric_memory(get_tp_group()) as sm:
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                topk_output=topk_output,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                routed_scaling_factor=self.routed_scaling_factor,
+                **(
+                    dict(
+                        tp_rank=self.moe_tp_rank,
+                        tp_size=self.moe_tp_size,
+                        ep_rank=self.moe_ep_rank,
+                        ep_size=self.moe_ep_size,
+                    )
+                    if self.quant_method.__class__.__name__
+                    == "ModelOptNvFp4FusedMoEMethod"
+                    else {}
+                ),
+            )
+            sm.tag(final_hidden_states)
 
         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
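Both here and in the vocab_parallel_embedding.py hunk further down, the computation that feeds a tensor-parallel all-reduce is wrapped in use_symmetric_memory(...) and the resulting tensor is tagged. The stand-in below only illustrates the API shape those call sites assume (a context manager taking a process group and yielding an object with a tag() method); it is not the pynccl_allocator implementation that this release adds under sglang/srt/distributed/device_communicators/pynccl_allocator.py.

from contextlib import contextmanager

@contextmanager
def use_symmetric_memory(group):
    # Hypothetical no-op scope: the real allocator is presumably expected to
    # route allocations made inside the block to NCCL symmetric memory so that
    # collectives on `group` can take a faster path; tag() marks the tensor
    # produced inside the scope.
    class _Scope:
        def tag(self, tensor):
            pass

    yield _Scope()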
@@ -686,3 +702,44 @@ class FusedMoE(torch.nn.Module):
             for expert_id in range(num_experts)
             for shard_id in ["w1", "w2", "w3"]
         ]
+
+
+class FlashInferFusedMoE(FusedMoE):
+    def __init__(self, *args, **kwargs):
+        renormalize = kwargs.pop("renormalize", True)
+        num_fused_shared_experts = kwargs.pop("num_fused_shared_experts", 0)
+        use_grouped_topk = kwargs.pop("use_grouped_topk", False)
+        num_expert_group = kwargs.pop("num_expert_group", None)
+        topk_group = kwargs.pop("topk_group", None)
+        correction_bias = kwargs.pop("correction_bias", None)
+        super().__init__(*args, **kwargs)
+        self.renormalize = renormalize
+        self.num_fused_shared_experts = num_fused_shared_experts
+        self.use_grouped_topk = use_grouped_topk
+        if self.use_grouped_topk:
+            assert num_expert_group is not None and topk_group is not None
+        self.num_expert_group = num_expert_group
+        self.topk_group = topk_group
+        self.correction_bias = correction_bias
+
+    def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        assert self.quant_method is not None
+        assert (
+            self.renormalize
+        ), "Renormalize is required for flashinfer blockscale fp8 moe"
+        assert (
+            self.num_fused_shared_experts == 0
+        ), "Fused shared experts are not supported for flashinfer blockscale fp8 moe"
+        # Matrix multiply.
+        final_hidden_states = self.quant_method.apply_with_router_logits(
+            layer=self,
+            x=hidden_states,
+            router_logits=router_logits,
+            activation=self.activation,
+            routed_scaling_factor=self.routed_scaling_factor,
+        )
+
+        if self.reduce_results and (self.tp_size > 1 or self.ep_size > 1):
+            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
+
+        return final_hidden_states
sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py
CHANGED
@@ -146,34 +146,3 @@ def triton_kernel_fused_experts(
     )
 
     return intermediate_cache3
-
-
-def triton_kernel_moe_forward_fake(
-    hidden_states: torch.Tensor,
-    w1: torch.Tensor,
-    w2: torch.Tensor,
-    gating_output: torch.Tensor,
-    topk: int,
-    renormalize: bool,
-    inplace: bool = False,
-    activation: str = "silu",
-    apply_router_weight_on_input: bool = False,
-    use_fp8_w8a8: bool = False,
-    per_channel_quant: bool = False,
-    global_num_experts: int = -1,
-    expert_map: Optional[torch.Tensor] = None,
-    w1_scale: Optional[torch.Tensor] = None,
-    w2_scale: Optional[torch.Tensor] = None,
-    a1_scale: Optional[torch.Tensor] = None,
-    a2_scale: Optional[torch.Tensor] = None,
-    block_shape: Optional[list[int]] = None,
-) -> torch.Tensor:
-    return torch.empty_like(hidden_states)
-
-
-direct_register_custom_op(
-    op_name="forward_cuda_triton",
-    op_func=triton_kernel_moe_forward,
-    mutates_args=[],
-    fake_impl=triton_kernel_moe_forward_fake,
-)
sglang/srt/layers/moe/token_dispatcher/__init__.py
ADDED
@@ -0,0 +1,23 @@
+from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
+    BaseDispatcher,
+    BaseDispatcherConfig,
+    DispatchOutput,
+    DispatchOutputFormat,
+)
+from sglang.srt.layers.moe.token_dispatcher.deepep import (
+    DeepEPConfig,
+    DeepEPDispatcher,
+    DeepEPLLOutput,
+    DeepEPNormalOutput,
+)
+
+__all__ = [
+    "BaseDispatcher",
+    "BaseDispatcherConfig",
+    "DispatchOutput",
+    "DispatchOutputFormat",
+    "DeepEPConfig",
+    "DeepEPDispatcher",
+    "DeepEPNormalOutput",
+    "DeepEPLLOutput",
+]
sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py
CHANGED
@@ -2,11 +2,22 @@ from __future__ import annotations
 
 from abc import ABC, abstractmethod
 from enum import Enum, auto
-from typing import
+from typing import Protocol, runtime_checkable
 
 import torch
 
 
+class MoEA2ABackend(Enum):
+    none = "none"
+    deepep = "deepep"
+
+    def is_none(self):
+        return self == MoEA2ABackend.none
+
+    def is_deepep(self):
+        return self == MoEA2ABackend.deepep
+
+
 class DispatchOutputFormat(Enum):
     standard = auto()
     deepep_normal = auto()
sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py}
CHANGED
@@ -1,5 +1,3 @@
-# TODO(ch-wan): this file will be moved to sglang/srt/layers/moe/token_dispatcher/deepep.py
-
 from __future__ import annotations
 
 import logging
@@ -22,15 +20,10 @@ from sglang.srt.layers.moe.token_dispatcher.base_dispatcher import (
     DispatchOutput,
     DispatchOutputFormat,
 )
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import (
-    DeepEPMode,
-    get_bool_env_var,
-    get_int_env_var,
-    is_hip,
-    load_json_config,
-)
+from sglang.srt.utils import get_bool_env_var, get_int_env_var, is_hip, load_json_config
 
 try:
     from deep_ep import Buffer, Config
@@ -150,9 +143,9 @@ class DeepEPBuffer:
             num_rdma_bytes,
         )
 
-        if deepep_mode == DeepEPMode.
+        if deepep_mode == DeepEPMode.NORMAL:
            num_qps_per_rank = DeepEPConfig.get_instance().num_sms // 2
-        elif deepep_mode in [DeepEPMode.
+        elif deepep_mode in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]:
            num_qps_per_rank = num_experts // group.size()
        else:
            raise NotImplementedError
@@ -161,7 +154,7 @@ class DeepEPBuffer:
             device="cuda"
         ).multi_processor_count
         if (
-            (deepep_mode != DeepEPMode.
+            (deepep_mode != DeepEPMode.LOW_LATENCY)
             and not global_server_args_dict["enable_two_batch_overlap"]
             and (DeepEPConfig.get_instance().num_sms < total_num_sms // 2)
         ):
@@ -611,7 +604,7 @@ class DeepEPDispatcher(BaseDispatcher):
         num_local_experts: int = None,
         hidden_size: int = None,
         params_dtype: torch.dtype = None,
-        deepep_mode: DeepEPMode = DeepEPMode.
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
         async_finish: bool = False,
         return_recv_hook: bool = False,
     ):
@@ -697,9 +690,9 @@ class DeepEPDispatcher(BaseDispatcher):
         resolved_deepep_mode = self.deepep_mode.resolve(
             forward_batch.is_extend_in_batch
         )
-        if resolved_deepep_mode == DeepEPMode.
+        if resolved_deepep_mode == DeepEPMode.NORMAL:
             return self._normal_dispatcher
-        elif resolved_deepep_mode == DeepEPMode.
+        elif resolved_deepep_mode == DeepEPMode.LOW_LATENCY:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
sglang/srt/layers/moe/utils.py
ADDED
@@ -0,0 +1,43 @@
+from enum import Enum
+
+
+class MoeA2ABackend(Enum):
+
+    STANDARD = ("standard", "none")
+    DEEPEP = "deepep"
+
+    @classmethod
+    def _missing_(cls, value):
+        if value is None:
+            return cls.STANDARD
+        for member in cls:
+            if value in member.value:
+                return member
+        raise ValueError(f"No {cls.__name__} member for value {value}")
+
+    def is_deepep(self):
+        return self == MoeA2ABackend.DEEPEP
+
+    def is_standard(self):
+        return self == MoeA2ABackend.STANDARD
+
+
+class DeepEPMode(Enum):
+    NORMAL = "normal"
+    LOW_LATENCY = "low_latency"
+    AUTO = "auto"
+
+    def enable_normal(self):
+        return self in [DeepEPMode.NORMAL, DeepEPMode.AUTO]
+
+    def enable_low_latency(self):
+        return self in [DeepEPMode.LOW_LATENCY, DeepEPMode.AUTO]
+
+    def resolve(self, is_extend_in_batch: bool):
+        if self != DeepEPMode.AUTO:
+            return self
+
+        if is_extend_in_batch:
+            return DeepEPMode.NORMAL
+        else:
+            return DeepEPMode.LOW_LATENCY
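The two enums above are the new home for the MoE all-to-all backend and DeepEP mode settings; the DeepEP dispatcher now imports DeepEPMode from here instead of sglang.srt.utils (see the deepep.py hunk above). MoeA2ABackend accepts legacy values through _missing_ (STANDARD carries the tuple ("standard", "none"), so both strings map to it), and DeepEPMode.AUTO picks a concrete mode per batch via resolve(). A short usage sketch of that behavior, assuming the module is importable as added above:

from sglang.srt.layers.moe.utils import DeepEPMode, MoeA2ABackend

# Legacy values and None all normalize to STANDARD.
assert MoeA2ABackend(None) is MoeA2ABackend.STANDARD
assert MoeA2ABackend("none") is MoeA2ABackend.STANDARD
assert MoeA2ABackend("standard") is MoeA2ABackend.STANDARD
assert MoeA2ABackend("deepep").is_deepep()

# AUTO resolves to NORMAL for extend (prefill) batches and LOW_LATENCY otherwise;
# explicit modes resolve to themselves.
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=True) is DeepEPMode.NORMAL
assert DeepEPMode.AUTO.resolve(is_extend_in_batch=False) is DeepEPMode.LOW_LATENCY
assert DeepEPMode.NORMAL.resolve(is_extend_in_batch=False) is DeepEPMode.NORMAL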
sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py
CHANGED
@@ -23,6 +23,7 @@ from sglang.srt.layers.quantization.utils import (
 from sglang.srt.utils import is_cpu, is_cuda, is_hip, is_npu, set_weight_attrs
 
 if TYPE_CHECKING:
+    from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
     from sglang.srt.layers.moe.topk import TopKOutput
     from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors import (
         CompressedTensorsConfig,
@@ -189,7 +190,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         layer.w13_input_scale = None
         layer.w2_input_scale = None
 
-    def process_weights_after_loading(self, layer:
+    def process_weights_after_loading(self, layer: FusedMoE) -> None:
         # Fp8 moe kernels require a single activation scale.
         # We take the max of all the scales in case they differ.
         if self.static_input_scales:
@@ -246,7 +247,7 @@ class CompressedTensorsW8A8Fp8MoEMethod(CompressedTensorsMoEMethod):
         assert layer.w13_weight_scale is not None
         shard_size = layer.intermediate_size_per_partition
         max_w13_scales = layer.w13_weight_scale.max(dim=1).values
-        for expert_id in range(layer.
+        for expert_id in range(layer.num_local_experts):
             start = 0
             for shard_id in range(2):
                 dq_weight = per_tensor_dequantize(
sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py
CHANGED
@@ -148,7 +148,7 @@ def _compile_grouped_gemm_nt_f8f8bf16_masked_one(
         "NUM_MATH_THREADS_PER_GROUP": num_math_threads_per_group,
         "N": n,
         "K": k,
-        "NUM_GROUPS":
+        "NUM_GROUPS": num_groups,
         "BLOCK_M": block_m,
         "BLOCK_N": block_n,
         "BLOCK_K": block_k,
sglang/srt/layers/quantization/fp8.py
CHANGED
@@ -72,6 +72,7 @@ from sglang.srt.utils import (
     is_hip,
     is_npu,
     log_info_on_rank0,
+    next_power_of_2,
     print_warning_once,
     set_weight_attrs,
     use_intel_amx_backend,
@@ -490,6 +491,16 @@ class Fp8LinearMethod(LinearMethodBase):
         )
 
 
+def get_tile_tokens_dim(num_tokens, top_k, num_experts):
+    # Guess tokens per expert assuming perfect expert distribution first.
+    num_tokens_per_expert = (num_tokens * top_k) // num_experts
+    # And pad the number to the next power of 2.
+    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
+    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
+    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
+    return tile_tokens_dim
+
+
 class Fp8MoEMethod(FusedMoEMethodBase):
     """MoE method for FP8.
     Supports loading FP8 checkpoints with static weight scale and
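get_tile_tokens_dim sizes the per-CTA token tile for the flashinfer TRT-LLM block-scale MoE kernel used below. A standalone illustration of the arithmetic, with a local next_power_of_2 standing in for sglang.srt.utils.next_power_of_2 (assumed here to round up to the nearest power of two):

def next_power_of_2(n: int) -> int:
    # Round up to the nearest power of two; 0 and 1 both map to 1.
    return 1 if n <= 1 else 1 << (n - 1).bit_length()

def get_tile_tokens_dim(num_tokens, top_k, num_experts):
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)

# 300 tokens routed to 8 of 256 experts: 2400 // 256 = 9 tokens per expert, padded to 16.
assert get_tile_tokens_dim(300, 8, 256) == 16
# A single decode token clamps to the kernel's minimum tile of 8.
assert get_tile_tokens_dim(1, 8, 256) == 8
# Very large batches clamp to the maximum tile of 64.
assert get_tile_tokens_dim(100_000, 8, 256) == 64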
@@ -1028,7 +1039,7 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             from sglang.srt.layers.moe.cutlass_moe import cutlass_fused_experts_fp8
 
             topk_weights, topk_ids, _ = topk_output
-
+            output = cutlass_fused_experts_fp8(
                 x,
                 layer.w13_weight.transpose(1, 2),
                 layer.w2_weight.transpose(1, 2),
@@ -1051,6 +1062,10 @@ class Fp8MoEMethod(FusedMoEMethodBase):
                 self.problem_sizes2,
                 use_fp8_blockscale=True,
             )
+            # TODO: Fuse into select_experts
+            if routed_scaling_factor is not None:
+                output *= routed_scaling_factor
+            return output
         # Expert fusion with FP8 quantization
         return fused_experts(
             x,
@@ -1076,6 +1091,47 @@ class Fp8MoEMethod(FusedMoEMethodBase):
             routed_scaling_factor=routed_scaling_factor,
         )
 
+    def apply_with_router_logits(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        router_logits: torch.Tensor,
+        *,
+        activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
+    ) -> torch.Tensor:
+        assert (
+            activation == "silu"
+        ), "Only silu is supported for flashinfer blockscale fp8 moe"
+        a_q, a_sf = per_token_group_quant_fp8(x, self.quant_config.weight_block_size[1])
+        # NOTE: scales of hidden states have to be transposed!
+        a_sf_t = a_sf.t().contiguous()
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
+            routing_logits=router_logits.to(torch.float32),
+            routing_bias=layer.correction_bias.to(x.dtype),
+            hidden_states=a_q,
+            hidden_states_scale=a_sf_t,
+            gemm1_weights=layer.w13_weight,
+            gemm1_weights_scale=layer.w13_weight_scale_inv,
+            gemm2_weights=layer.w2_weight,
+            gemm2_weights_scale=layer.w2_weight_scale_inv,
+            num_experts=layer.num_experts,
+            top_k=layer.top_k,
+            n_group=layer.num_expert_group,
+            topk_group=layer.topk_group,
+            intermediate_size=layer.w2_weight.shape[2],
+            local_expert_offset=layer.moe_ep_rank * layer.num_local_experts,
+            local_num_experts=layer.num_local_experts,
+            routed_scaling_factor=routed_scaling_factor,
+            tile_tokens_dim=get_tile_tokens_dim(
+                x.shape[0], layer.top_k, layer.num_experts
+            ),
+            routing_method_type=2,  # DeepSeek-styled routing method
+            use_shuffled_weight=False,
+        )
+
     def maybe_apply_hip_fused_experts(
         self,
         layer: torch.nn.Module,
sglang/srt/layers/quantization/fp8_kernel.py
CHANGED
@@ -354,10 +354,6 @@ def sglang_per_token_group_quant_fp8(
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    if scale_ue8m0:
-        # TODO: handle this case by fixing the (token=4, dim=256, group_size=128) UT case
-        assert x.shape[-1] % (group_size * 4) == 0
-
     x_q = torch.empty_like(x, device=x.device, dtype=fp8_dtype)
     x_s = create_per_token_group_quant_fp8_output_scale(
         x_shape=x.shape,
sglang/srt/layers/quantization/w8a8_int8.py
CHANGED
@@ -231,7 +231,10 @@ class W8A8Int8Config(QuantizationConfig):
 
     @classmethod
     def get_config_filenames(cls) -> List[str]:
-
+        filenames = []
+        if _is_npu:
+            filenames.append("quant_model_description.json")
+        return filenames
 
     @classmethod
     def from_config(cls, config: Dict[str, Any]) -> W8A8Int8Config:
sglang/srt/layers/vocab_parallel_embedding.py
CHANGED
@@ -11,8 +11,12 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.amx_utils import PackWeightMethod
 from sglang.srt.layers.dp_attention import get_attention_tp_rank, get_attention_tp_size
 from sglang.srt.layers.parameter import BasevLLMParameter
@@ -464,7 +468,9 @@ class VocabParallelEmbedding(torch.nn.Module):
         else:
             masked_input = input_
         # Get the embeddings.
-
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.embedding(self, masked_input.long())
+            sm.tag(output_parallel)
         # Mask the output embedding.
         if self.tp_size > 1:
             output_parallel.masked_fill_(input_mask.unsqueeze(-1), 0)
sglang/srt/lora/lora_registry.py
CHANGED
@@ -186,3 +186,10 @@ class LoRARegistry:
         self._registry[lora_ref.lora_name] = lora_ref
         self._counters[lora_ref.lora_id] = ConcurrentCounter()
         return lora_ref
+
+    @property
+    def num_registered_loras(self) -> int:
+        """
+        Returns the total number of LoRA adapters currently registered.
+        """
+        return len(self._registry)