sglang 0.4.9.post6__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +4 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +39 -674
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +152 -39
- sglang/srt/layers/quantization/fp8.py +52 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/managers/cache_controller.py +165 -67
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +90 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +60 -17
- sglang/srt/mem_cache/hiradix_cache.py +36 -8
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +418 -29
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +183 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/cuda_graph_runner.py +25 -1
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +7 -6
- sglang/srt/models/glm4_moe.py +6 -4
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +991 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +49 -18
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/utils.py +1 -0
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +83 -65
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
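
Note: after upgrading, the installed version can be confirmed from package metadata; a minimal sketch (assumes a standard pip install of the new wheel):

```python
# Illustrative post-upgrade sanity check; assumes sglang was installed via pip.
import importlib.metadata

assert importlib.metadata.version("sglang") == "0.4.10.post1"
```

The hunks below cover sglang/srt/layers/moe/ep_moe/layer.py (+39 -674), one of the largest changes in this release.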
sglang/srt/layers/moe/ep_moe/layer.py

@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
+)
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    next_power_of_2,
 )
 
 if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,99 +80,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)
 
 
-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
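
Note: the local `_get_tile_tokens_dim` helper removed above is replaced by the `get_tile_tokens_dim` import from `sglang.srt.layers.quantization.fp8` (first hunk). A standalone sketch of the same heuristic, with `next_power_of_2` inlined since that import was also dropped from this file:

```python
# Illustrative restatement of the removed tile-size heuristic (not the shipped helper).
def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming a perfect expert distribution first.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    # Pad to the next power of 2, then cap to the 8-64 range supported by the kernel.
    padded = 1 if num_tokens_per_expert <= 1 else 1 << (num_tokens_per_expert - 1).bit_length()
    return min(max(padded, 8), 64)
```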
@@ -183,140 +97,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )
 
-
-        params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
             )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-
-
-
-
-                layer=self,
-                num_experts=self.num_local_experts,
-                hidden_size=hidden_size,
-                intermediate_size=self.intermediate_size,
-                params_dtype=params_dtype,
-                weight_loader=self.weight_loader,
-            )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return
+            return super().forward(hidden_states, topk_output)
 
     def forward_deepgemm(
         self,
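
Note: with `determine_expert_map` gone, the per-rank bookkeeping that survives in `EPMoE.__init__` reduces to the two `start_expert_id` / `end_expert_id` lines kept above. A minimal sketch of that arithmetic, assuming experts divide evenly across the expert-parallel ranks:

```python
# Sketch of the retained per-rank expert range (illustrative only; assumes
# num_experts is evenly divisible by the expert-parallel world size).
def expert_range(moe_ep_rank: int, moe_ep_size: int, num_experts: int) -> tuple[int, int]:
    num_local_experts = num_experts // moe_ep_size
    start_expert_id = moe_ep_rank * num_local_experts
    end_expert_id = start_expert_id + num_local_experts - 1
    return start_expert_id, end_expert_id

assert expert_range(1, 4, 64) == (16, 31)
```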
@@ -475,294 +306,6 @@ class EPMoE(FusedMoE):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-
 
 class DeepEPMoE(EPMoE):
     """
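
Note: the checkpoint-name mapping that the removed `make_expert_params_mapping` produced is a cross product of expert ids and shard names, presumably now inherited from the `FusedMoE` base class. A hedged standalone sketch of that mapping logic:

```python
# Standalone restatement of the removed mapping (illustrative only; not the
# maintained implementation, which EPMoE now gets from its base class).
def expert_params_mapping(gate: str, down: str, up: str, num_experts: int):
    # (param_name, weight_name, expert_id, shard_id)
    return [
        (
            "experts.w13_" if weight_name in (gate, up) else "experts.w2_",
            f"experts.{expert_id}.{weight_name}.",
            expert_id,
            shard_id,
        )
        for expert_id in range(num_experts)
        for shard_id, weight_name in (("w1", gate), ("w2", down), ("w3", up))
    ]
```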
@@ -778,6 +321,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +336,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -892,14 +437,15 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
-
-
-        else:
-            return self.forward_normal(dispatch_output)
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
         elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
+            )
 
     def combine(
         self,
@@ -915,185 +461,6 @@ class DeepEPMoE(EPMoE):
             forward_batch=forward_batch,
         )
 
-    def _prepare_for_normal(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-    ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_permute_triton_kernel,
-            deepep_run_moe_deep_preprocess,
-        )
-
-        if hidden_states.shape[0] == 0:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,),
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-        else:
-            if _use_aiter:
-                # skip permutation here as aiter fused_moe has fused inside
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-                return reorder_topk_ids, seg_indptr, hidden_states
-
-            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-                topk_idx, self.num_experts
-            )
-            num_total_tokens = reorder_topk_ids.numel()
-            gateup_input = torch.empty(
-                (int(num_total_tokens), hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            # PreReorder
-            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-                hidden_states,
-                gateup_input,
-                self.src2dst,
-                topk_idx,
-                None,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
-            return reorder_topk_ids, seg_indptr, gateup_input
-
-    def forward_normal(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        hidden_states, topk_idx = (
-            dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
-        )
-        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
-            hidden_states, topk_idx
-        )
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        assert self.quant_method is not None
-        assert self.activation == "silu"
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
-        )
-
-        # GroupGemm-0
-        if hidden_states.shape[0] > 0:
-            gateup_output = self.grouped_gemm_runner(
-                a=hidden_states,
-                b=self.w13_weight,
-                c=None,
-                c_dtype=hidden_states.dtype,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        return down_output
-
     def forward_aiter(
         self,
         dispatch_output: DeepEPNormalOutput,
@@ -1351,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe =
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -1367,8 +734,9 @@ class FlashInferEPMoE(EPMoE):
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-
-
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -1385,7 +753,7 @@ class FlashInferEPMoE(EPMoE):
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -1399,9 +767,6 @@ def get_moe_impl_class():
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
        return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
     if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE