sglang 0.4.9.post5__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +23 -3
- sglang/srt/entrypoints/openai/protocol.py +3 -1
- sglang/srt/entrypoints/openai/serving_base.py +5 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +98 -603
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +83 -118
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=192,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=384,device_name=NVIDIA_H200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/moe/token_dispatcher/__init__.py +0 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +48 -0
- sglang/srt/layers/moe/token_dispatcher/standard.py +19 -0
- sglang/srt/layers/moe/topk.py +6 -2
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/modelopt_quant.py +2 -0
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +6 -0
- sglang/srt/managers/io_struct.py +12 -2
- sglang/srt/managers/scheduler.py +116 -669
- sglang/srt/managers/scheduler_input_blocker.py +106 -0
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +166 -83
- sglang/srt/managers/tp_worker.py +9 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +20 -13
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +15 -56
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/qwen3_moe.py +12 -69
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/poll_based_barrier.py +31 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +18 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/srt/two_batch_overlap.py +8 -3
- sglang/test/test_utils.py +53 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/METADATA +4 -4
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/RECORD +84 -64
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post5.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
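Before the per-file hunks, one recurring pattern in the `ep_moe/layer.py` changes below is worth noting: the file gains `from __future__ import annotations` plus a `TYPE_CHECKING`-guarded import of the new dispatcher output types (`DispatchOutput`, `DeepEPNormalOutput`, `DeepEPLLOutput`), so the annotations resolve for type checkers without importing the dispatcher module at runtime. A minimal sketch of that pattern follows; `heavy_module` is a hypothetical placeholder, not one of sglang's modules.

```python
# Sketch of the TYPE_CHECKING import pattern added in the layer.py hunks below.
# "heavy_module" stands in for an import that would be circular or expensive at runtime.
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Evaluated only by static type checkers, never at runtime.
    from heavy_module import DispatchOutput


def consume(output: DispatchOutput) -> None:
    # With `from __future__ import annotations`, annotations stay unevaluated
    # strings at runtime, so DispatchOutput does not need to be importable here.
    print(type(output).__name__)


consume(object())  # runs even though heavy_module was never imported
```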
sglang/srt/layers/moe/ep_moe/layer.py

@@ -1,5 +1,7 @@
+from __future__ import annotations
+
 import logging
-from typing import List, Optional, Tuple
+from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
@@ -50,6 +52,13 @@ from sglang.srt.utils import (
     next_power_of_2,
 )
 
+if TYPE_CHECKING:
+    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+        DeepEPLLOutput,
+        DeepEPNormalOutput,
+        DispatchOutput,
+    )
+
 _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
@@ -77,79 +86,6 @@ if use_flashinfer_trtllm_moe:
 logger = logging.getLogger(__name__)
 
 
-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
@@ -174,140 +110,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
            routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )
 
-
-        params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1
 
         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
 
-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
             )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-
-
-
-
-            layer=self,
-            num_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None
 
     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return
+            return super().forward(hidden_states, topk_output)
 
     def forward_deepgemm(
         self,
@@ -466,294 +319,6 @@ class EPMoE(FusedMoE):
         )
         return output
 
-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-
 
 class DeepEPMoE(EPMoE):
     """
@@ -769,6 +334,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -783,6 +349,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -791,11 +358,24 @@ class DeepEPMoE(EPMoE):
             routed_scaling_factor=routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
-
-
-
-
-
+
+        # TODO: move to the beginning of the file
+        from sglang.srt.distributed.parallel_state import get_tp_group
+        from sglang.srt.managers.schedule_batch import global_server_args_dict
+        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
+
+        self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
+            group=get_tp_group().device_group,
+            router_topk=self.top_k,
+            permute_fusion=True,
+            num_experts=self.num_experts,
+            num_local_experts=self.num_local_experts,
+            hidden_size=hidden_size,
+            params_dtype=params_dtype,
+            deepep_mode=deepep_mode,
+            async_finish=True,  # TODO
+            return_recv_hook=True,
+        )
 
         if self.deepep_mode.enable_low_latency():
             assert (
@@ -837,156 +417,72 @@ class DeepEPMoE(EPMoE):
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
-        reorder_topk_ids: torch.Tensor,
-        seg_indptr: torch.Tensor,
-        masked_m: torch.Tensor,
-        expected_m: int,
-        num_recv_tokens_per_expert: List[int],
         forward_batch: ForwardBatch,
     ):
-
-
-            return self.forward_aiter(hidden_states, topk_idx, topk_weights)
-        resolved_deepep_mode = self.deepep_mode.resolve(
-            forward_batch.is_extend_in_batch
+        dispatch_output = self.dispatch(
+            hidden_states, topk_idx, topk_weights, forward_batch
         )
-
-
-
-
-
-
-
-
-            return self.forward_deepgemm_masked(hidden_states, masked_m, expected_m)
-        else:
-            raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+        hidden_states = self.moe_impl(dispatch_output)
+        hidden_states = self.combine(
+            hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+            forward_batch,
+        )
+        return hidden_states
 
-    def
+    def dispatch(
         self,
         hidden_states: torch.Tensor,
-
-
+        topk_idx: torch.Tensor,
+        topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
-
-
-
-
-
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
+        return self.deepep_dispatcher.dispatch(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
         )
 
-
-        if
-
-
-
-
-
-
-
-
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
+    def moe_impl(self, dispatch_output: DispatchOutput):
+        if _use_aiter:
+            # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
+            return self.forward_aiter(dispatch_output)
+        if dispatch_output.format.is_deepep_normal():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
+        elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
             )
-        return down_output
 
-    def
+    def combine(
         self,
         hidden_states: torch.Tensor,
         topk_idx: torch.Tensor,
         topk_weights: torch.Tensor,
+        forward_batch: ForwardBatch,
     ):
+        return self.deepep_dispatcher.combine(
+            hidden_states=hidden_states,
+            topk_idx=topk_idx,
+            topk_weights=topk_weights,
+            forward_batch=forward_batch,
+        )
+
+    def forward_aiter(
+        self,
+        dispatch_output: DeepEPNormalOutput,
+    ):
+        hidden_states, topk_idx, topk_weights = (
+            dispatch_output.hidden_states,
+            dispatch_output.topk_idx,
+            dispatch_output.topk_weights,
+        )
         if hidden_states.shape[0] == 0:
             return hidden_states
         # in original deepep, idx == -1 meaning invalid and will not be processed.
@@ -1014,11 +510,11 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_contiguous(
         self,
-
-        topk_idx,
-        topk_weights,
-        num_recv_tokens_per_expert: List[int],
+        dispatch_output: DeepEPNormalOutput,
     ):
+        hidden_states_fp8, topk_idx, topk_weights, num_recv_tokens_per_expert = (
+            dispatch_output
+        )
         hidden_states_fp8, hidden_states_scale = hidden_states_fp8
         assert self.quant_method is not None
         assert self.activation == "silu"
@@ -1138,10 +634,9 @@ class DeepEPMoE(EPMoE):
 
     def forward_deepgemm_masked(
         self,
-
-        masked_m: torch.Tensor,
-        expected_m: int,
+        dispatch_output: DeepEPLLOutput,
    ):
+        hidden_states_fp8, _, _, masked_m, expected_m = dispatch_output
         assert self.quant_method is not None
         assert self.activation == "silu"
 
@@ -1268,7 +763,7 @@ class FlashInferEPMoE(EPMoE):
             topk_group=self.topk_group,
             intermediate_size=self.w2_weight.shape[2],
             local_expert_offset=self.start_expert_id,
-            local_num_experts=self.
+            local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
             tile_tokens_dim=_get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts