sglang 0.4.9.post6__py3-none-any.whl → 0.4.10__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a public registry. It is provided for informational purposes only and reflects the changes between the versions as they appear in that registry.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +8 -0
- sglang/srt/configs/model_config.py +3 -0
- sglang/srt/configs/step3_vl.py +172 -0
- sglang/srt/conversation.py +23 -0
- sglang/srt/disaggregation/decode.py +2 -8
- sglang/srt/disaggregation/prefill.py +2 -6
- sglang/srt/distributed/parallel_state.py +86 -1
- sglang/srt/entrypoints/engine.py +14 -18
- sglang/srt/entrypoints/http_server.py +10 -2
- sglang/srt/entrypoints/openai/serving_chat.py +2 -21
- sglang/srt/eplb/expert_distribution.py +5 -0
- sglang/srt/eplb/expert_location.py +17 -6
- sglang/srt/eplb/expert_location_dispatch.py +1 -0
- sglang/srt/eplb/expert_location_updater.py +2 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/step3_detector.py +436 -0
- sglang/srt/hf_transformers_utils.py +2 -0
- sglang/srt/jinja_template_utils.py +4 -1
- sglang/srt/layers/moe/cutlass_moe.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +20 -640
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +26 -13
- sglang/srt/layers/moe/fused_moe_triton/layer.py +97 -38
- sglang/srt/layers/quantization/fp8.py +0 -18
- sglang/srt/layers/quantization/unquant.py +0 -8
- sglang/srt/layers/quantization/w4afp8.py +1 -0
- sglang/srt/managers/cache_controller.py +143 -45
- sglang/srt/managers/data_parallel_controller.py +2 -0
- sglang/srt/managers/io_struct.py +0 -2
- sglang/srt/managers/scheduler.py +89 -671
- sglang/srt/managers/scheduler_metrics_mixin.py +229 -0
- sglang/srt/managers/scheduler_profiler_mixin.py +279 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +142 -0
- sglang/srt/managers/template_manager.py +62 -19
- sglang/srt/managers/tokenizer_manager.py +123 -74
- sglang/srt/managers/tp_worker.py +4 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +2 -1
- sglang/srt/mem_cache/hicache_storage.py +45 -11
- sglang/srt/mem_cache/hiradix_cache.py +15 -4
- sglang/srt/mem_cache/memory_pool_host.py +73 -1
- sglang/srt/mem_cache/mooncake_store/mooncake_store.py +264 -0
- sglang/srt/mem_cache/mooncake_store/unit_test.py +40 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +177 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +278 -0
- sglang/srt/mem_cache/storage/hf3fs/test_hf3fs_utils.py +43 -0
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/arcee.py +532 -0
- sglang/srt/models/deepseek_v2.py +2 -0
- sglang/srt/models/glm4_moe.py +3 -1
- sglang/srt/models/granitemoe.py +3 -0
- sglang/srt/models/grok.py +3 -0
- sglang/srt/models/hunyuan.py +1 -0
- sglang/srt/models/llama4.py +3 -0
- sglang/srt/models/mixtral.py +3 -0
- sglang/srt/models/olmoe.py +3 -0
- sglang/srt/models/phimoe.py +1 -0
- sglang/srt/models/step3_vl.py +994 -0
- sglang/srt/multimodal/processors/base_processor.py +15 -16
- sglang/srt/multimodal/processors/step3_vl.py +515 -0
- sglang/srt/reasoning_parser.py +2 -1
- sglang/srt/server_args.py +10 -13
- sglang/srt/speculative/eagle_worker.py +2 -0
- sglang/utils.py +0 -11
- sglang/version.py +1 -1
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/METADATA +3 -4
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/RECORD +69 -56
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/WHEEL +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.9.post6.dist-info → sglang-0.4.10.dist-info}/top_level.txt +0 -0
@@ -86,79 +86,6 @@ if use_flashinfer_trtllm_moe:
 logger = logging.getLogger(__name__)


-class GroupedGemmRunner(torch.nn.Module):
-    flashinfer_gemm_warpper = None
-
-    def __init__(
-        self,
-        device,
-        use_flashinfer: bool = False,
-        use_per_token_if_dynamic: bool = True,
-    ):
-        super().__init__()
-        self.device = device
-        self.use_flashinfer = use_flashinfer
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic
-        if self.use_flashinfer and GroupedGemmRunner.flashinfer_gemm_warpper is None:
-            GroupedGemmRunner._init_flashinfer_wrapper(device)
-
-    @classmethod
-    def _init_flashinfer_wrapper(cls, device):
-        from flashinfer import SegmentGEMMWrapper
-
-        workspace_buffer = torch.empty(
-            128 * 1024 * 1024, dtype=torch.int8, device=device
-        )
-        cls.flashinfer_gemm_warpper = SegmentGEMMWrapper(workspace_buffer)
-
-    # c = a * b
-    def forward(
-        self,
-        a: torch.Tensor,
-        b: torch.Tensor,
-        c: torch.Tensor,
-        batch_size: int,
-        weight_column_major: bool,
-        seg_indptr: Optional[torch.Tensor] = None,
-        weight_indices: Optional[torch.Tensor] = None,
-        use_fp8_w8a8: bool = False,
-        scale_a: torch.Tensor = None,
-        scale_b: torch.Tensor = None,
-        block_shape: Optional[List[int]] = None,
-        c_dtype=None,
-    ):
-        if self.use_flashinfer:
-            # TODO: flashinfer
-            assert False
-            assert GroupedGemmRunner.flashinfer_gemm_warpper is not None
-            c = GroupedGemmRunner.flashinfer_gemm_warpper.run(
-                x=a,
-                weights=b,
-                batch_size=batch_size,
-                weight_column_major=weight_column_major,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices,
-            )
-        else:
-            assert weight_column_major == True
-            c = grouped_gemm_triton(
-                a,
-                b,
-                c,
-                batch_size,
-                weight_column_major,
-                seg_indptr,
-                weight_indices,
-                use_fp8_w8a8,
-                scale_a,
-                scale_b,
-                block_shape=block_shape,
-                c_dtype=c_dtype,
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-        return c
-
-
 def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
     # Guess tokens per expert assuming perfect expert distribution first.
     num_tokens_per_expert = (num_tokens * top_k) // num_experts
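Note: the GroupedGemmRunner removed above runs one GEMM per expert over contiguous token segments delimited by seg_indptr (tokens are pre-sorted by expert). The snippet below is a minimal pure-PyTorch sketch of that segmented computation for orientation only; it is not code from the package, and the function name, shapes, and the weight_column_major interpretation are assumptions.

    import torch

    def grouped_gemm_reference(a, b, seg_indptr):
        # Illustrative reference only (not from sglang).
        # a: (num_tokens, k) with rows sorted by expert; b: (num_experts, n, k);
        # seg_indptr: (num_experts + 1,) prefix sums delimiting each expert's rows.
        out = a.new_empty(a.shape[0], b.shape[1])
        for e in range(b.shape[0]):
            start, end = int(seg_indptr[e]), int(seg_indptr[e + 1])
            if end > start:
                # Column-major per-expert weights: each segment computes a_seg @ b[e].T.
                out[start:end] = a[start:end] @ b[e].t()
        return out

    # Tiny usage example: 6 tokens, hidden size 4, 2 experts, output size 8.
    a = torch.randn(6, 4)
    b = torch.randn(2, 8, 4)
    seg_indptr = torch.tensor([0, 2, 6])  # tokens 0-1 -> expert 0, tokens 2-5 -> expert 1
    print(grouped_gemm_reference(a, b, seg_indptr).shape)  # torch.Size([6, 8])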
@@ -183,140 +110,57 @@ class EPMoE(FusedMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        use_per_token_if_dynamic: bool = True,
     ):
         super().__init__(
             num_experts=num_experts,
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
-
+            num_fused_shared_experts=num_fused_shared_experts,
             layer_id=layer_id,
+            top_k=top_k,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
             prefix=prefix,
             activation=activation,
+            # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
             enable_ep_moe=True,
-            skip_quant=True,
         )

-
-        params_dtype = torch.get_default_dtype()
-
-        self.layer_id = layer_id
-        self.num_local_experts, self.expert_map = self.determine_expert_map()
-        self.start_expert_id = self.ep_rank * self.num_local_experts
+        self.start_expert_id = self.moe_ep_rank * self.num_local_experts
         self.end_expert_id = self.start_expert_id + self.num_local_experts - 1

         self.intermediate_size = intermediate_size
-        self.use_per_token_if_dynamic = use_per_token_if_dynamic

-
-        if quant_config is None:
-            self.quant_method: Optional[QuantizeMethodBase] = (
-                UnquantizedFusedMoEMethod()
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.block_shape = None
-            self.activation_scheme = None
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-        elif isinstance(quant_config, W4AFp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = W4AFp8MoEMethod(
-                quant_config
-            )
-            self.use_fp8_w8a8 = False
-            self.use_block_quant = False
-            self.fp8_dtype = torch.float8_e4m3fn
-            self.w13_input_scale = None
-            self.w2_input_scale = None
-            self.w13_weight_scale = None
-            self.w2_weight_scale = None
-            self.activation_scheme = quant_config.moe_activation_scheme
-        elif isinstance(quant_config, Fp8Config):
-            self.quant_method: Optional[QuantizeMethodBase] = Fp8MoEMethod(quant_config)
-            self.use_fp8_w8a8 = True
+        if isinstance(quant_config, Fp8Config):
             self.use_block_quant = getattr(self.quant_method, "block_quant", False)
             self.block_shape = (
                 self.quant_method.quant_config.weight_block_size
                 if self.use_block_quant
                 else None
             )
+            self.use_fp8_w8a8 = True
             self.fp8_dtype = torch.float8_e4m3fn
             self.activation_scheme = quant_config.activation_scheme
         else:
-
-
-
-
-            layer=self,
-            num_experts=self.num_local_experts,
-            hidden_size=hidden_size,
-            intermediate_size=self.intermediate_size,
-            params_dtype=params_dtype,
-            weight_loader=self.weight_loader,
-        )
-
-        self.grouped_gemm_runner = None
-
-    # Adapted from https://github.com/vllm-project/vllm/blob/9fb52e523abf7bdaf7e60cf2971edb5a1b13dc08/vllm/model_executor/layers/fused_moe/layer.py#L544C1-L586C43
-    # Modifications: use determine_expert_map as a class internal function, set 'global_num_experts' rather than '-1' for experts not assigned to the current rank.
-    def determine_expert_map(self) -> Tuple[int, Optional[torch.Tensor]]:
-        """
-        Calculates how many experts should be assigned to each rank for EP and
-        creates a mapping from global to local expert index. Experts are
-        distributed evenly across ranks. Any remaining are assigned to the
-        last rank.
-
-        Returns:
-            Tuple[int, Optional[torch.Tensor]]: A tuple containing:
-                - local_num_experts (int): The number of experts assigned
-                    to the current rank.
-                - expert_map (Optional[torch.Tensor]): A tensor of shape
-                    (global_num_experts,) mapping from global to local index.
-                    Contains global_num_experts for experts not assigned to the current rank.
-                    Returns None if ep_size is 1.
-        """
-        ep_size = self.ep_size
-        ep_rank = self.ep_rank
-        global_num_experts = self.num_experts
-
-        assert ep_size > 0
-        if ep_size == 1:
-            return (global_num_experts, None)
-
-        local_num_experts = global_num_experts // ep_size
-
-        expert_map = torch.full(
-            (global_num_experts,), global_num_experts, dtype=torch.int32
-        )
-        if ep_rank < (ep_size - 1):
-            expert_map[
-                ep_rank * local_num_experts : (ep_rank + 1) * local_num_experts
-            ] = torch.arange(0, local_num_experts, dtype=torch.int32)
-        else:
-            local_num_experts = global_num_experts - ep_rank * local_num_experts
-
-            expert_map[-local_num_experts:] = torch.arange(
-                0, local_num_experts, dtype=torch.int32
-            )
-        return (local_num_experts, expert_map)
+            self.use_fp8_w8a8 = False
+            self.use_block_quant = False
+            self.block_shape = None
+            self.activation_scheme = None

     def forward(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
         if deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8:
             return self.forward_deepgemm(hidden_states, topk_output)
         else:
-            return
+            return super().forward(hidden_states, topk_output)

     def forward_deepgemm(
         self,
@@ -475,294 +319,6 @@ class EPMoE(FusedMoE):
         )
         return output

-    def forward_normal(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-        return self.quant_method.apply(self, hidden_states, topk_output)
-
-    def run_moe(self, hidden_states: torch.Tensor, topk_output: TopKOutput):
-
-        topk_weights, topk_ids, _ = topk_output
-
-        hidden_states_shape = hidden_states.shape
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device,
-                use_flashinfer=False,  # TODO: use flashinfer
-                use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-            )
-
-        num_experts = self.num_experts
-
-        reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
-            topk_ids,
-            num_experts,
-        )
-
-        gateup_input = torch.empty(
-            (int(hidden_states.shape[0] * self.top_k), hidden_states.shape[1]),
-            device=hidden_states.device,
-            dtype=(
-                self.fp8_dtype
-                if self.use_fp8_w8a8 and not self.use_block_quant
-                else hidden_states.dtype
-            ),
-        )
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                max_value = torch.max(hidden_states, dim=1).values.to(torch.float32)
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-            else:
-                max_value = (
-                    torch.max(hidden_states)
-                    .repeat(self.num_local_experts)
-                    .to(torch.float32)
-                )
-                self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-
-        # PreReorder
-        pre_reorder_triton_kernel[(hidden_states.shape[0],)](
-            hidden_states,
-            gateup_input,
-            src2dst,
-            topk_ids,
-            self.w13_input_scale,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states.shape[1],
-            BLOCK_SIZE=512,
-            use_per_token_if_dynamic=self.use_per_token_if_dynamic,
-        )
-        dispose_tensor(hidden_states)
-
-        if (
-            self.activation_scheme == "dynamic"
-            and not self.use_block_quant
-            and self.use_per_token_if_dynamic
-        ):
-            scale = torch.empty(
-                hidden_states_shape[0] * self.top_k,
-                device=hidden_states_device,
-                dtype=torch.float32,
-            )
-            scale[src2dst] = (
-                self.w13_input_scale.unsqueeze(1)
-                .expand(hidden_states_shape[0], self.top_k)
-                .reshape(-1)
-            )
-            self.w13_input_scale = scale
-
-        seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states_device,
-            dtype=torch.int64,
-        )
-        # GroupGemm-0
-        gateup_output = self.grouped_gemm_runner(
-            a=gateup_input,
-            b=self.w13_weight,
-            c=None,
-            c_dtype=hidden_states_dtype,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w13_input_scale,
-            scale_b=self.w13_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del gateup_input
-
-        # Act
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            self.w2_input_scale = None
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=hidden_states_dtype,
-            )
-        else:
-            down_input = torch.empty(
-                gateup_output.shape[0],
-                gateup_output.shape[1] // 2,
-                device=gateup_output.device,
-                dtype=(
-                    self.fp8_dtype
-                    if (self.use_fp8_w8a8 and not self.use_block_quant)
-                    else hidden_states_dtype
-                ),
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        elif self.activation == "gelu":
-            gelu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                self.start_expert_id,
-                self.end_expert_id,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-        del gateup_output
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            if self.use_per_token_if_dynamic:
-                down_input, self.w2_input_scale = sglang_per_token_quant_fp8(down_input)
-            else:
-                self.w2_input_scale = torch.ones(
-                    self.num_local_experts,
-                    dtype=torch.float32,
-                    device=hidden_states_device,
-                )
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        down_output = self.grouped_gemm_runner(
-            a=down_input,
-            b=self.w2_weight,
-            c=down_output,
-            batch_size=self.num_local_experts,
-            weight_column_major=True,
-            seg_indptr=seg_indptr_cur_rank,
-            weight_indices=weight_indices_cur_rank,
-            use_fp8_w8a8=self.use_fp8_w8a8,
-            scale_a=self.w2_input_scale,
-            scale_b=self.w2_weight_scale,
-            block_shape=self.block_shape,
-        )
-        del down_input
-
-        # PostReorder
-        output = torch.empty(
-            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
-        )
-        post_reorder_triton_kernel[(hidden_states_shape[0],)](
-            down_output,
-            output,
-            src2dst,
-            topk_ids,
-            topk_weights,
-            self.start_expert_id,
-            self.end_expert_id,
-            self.top_k,
-            hidden_states_shape[1],
-            0,
-            BLOCK_SIZE=512,
-        )
-        return output
-
-    @classmethod
-    def make_expert_params_mapping(
-        cls,
-        ckpt_gate_proj_name: str,
-        ckpt_down_proj_name: str,
-        ckpt_up_proj_name: str,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        return [
-            # (param_name, weight_name, expert_id, shard_id)
-            (
-                (
-                    "experts.w13_"
-                    if weight_name in [ckpt_gate_proj_name, ckpt_up_proj_name]
-                    else "experts.w2_"
-                ),
-                f"experts.{expert_id}.{weight_name}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id, weight_name in [
-                ("w1", ckpt_gate_proj_name),
-                ("w2", ckpt_down_proj_name),
-                ("w3", ckpt_up_proj_name),
-            ]
-        ]
-
-    @classmethod
-    def make_expert_input_scale_params_mapping(
-        cls,
-        num_experts: int,
-    ) -> List[Tuple[str, str, int, str]]:
-        # (param_name, weight_name, expert_id, shard_id)
-        return [
-            (
-                "experts.w13_" if shard_id in ["w1", "w3"] else "experts.w2_",
-                f"experts.{expert_id}.{shard_id}.",
-                expert_id,
-                shard_id,
-            )
-            for expert_id in range(num_experts)
-            for shard_id in ["w1", "w2", "w3"]
-        ]
-
-    def weight_loader(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        physical_expert_ids = (
-            get_global_expert_location_metadata().logical_to_all_physical(
-                self.layer_id, expert_id
-            )
-        )
-        for physical_expert_id in physical_expert_ids:
-            self._weight_loader_physical(
-                param=param,
-                loaded_weight=loaded_weight,
-                weight_name=weight_name,
-                shard_id=shard_id,
-                expert_id=physical_expert_id,
-            )
-
-    def _weight_loader_physical(
-        self,
-        param: torch.nn.Parameter,
-        loaded_weight: torch.Tensor,
-        weight_name: str,
-        shard_id: str,
-        expert_id: int,
-    ) -> None:
-        if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
-            return
-        expert_id = expert_id - self.start_expert_id
-
-        self._weight_loader_impl(
-            param=param,
-            loaded_weight=loaded_weight,
-            weight_name=weight_name,
-            shard_id=shard_id,
-            expert_id=expert_id,
-        )
-        return
-

 class DeepEPMoE(EPMoE):
     """
@@ -778,6 +334,7 @@ class DeepEPMoE(EPMoE):
         hidden_size: int,
         intermediate_size: int,
         layer_id: int,
+        num_fused_shared_experts: int = 0,
         params_dtype: Optional[torch.dtype] = None,
         quant_config: Optional[QuantizationConfig] = None,
         tp_size: Optional[int] = None,
@@ -792,6 +349,7 @@ class DeepEPMoE(EPMoE):
             hidden_size=hidden_size,
             intermediate_size=intermediate_size,
             layer_id=layer_id,
+            num_fused_shared_experts=num_fused_shared_experts,
             params_dtype=params_dtype,
             quant_config=quant_config,
             tp_size=tp_size,
@@ -892,14 +450,15 @@ class DeepEPMoE(EPMoE):
             # in forward_aiter, we skip token permutation and unpermutation, which have been fused inside aiter kernel
             return self.forward_aiter(dispatch_output)
         if dispatch_output.format.is_deepep_normal():
-
-
-            else:
-                return self.forward_normal(dispatch_output)
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
+            return self.forward_deepgemm_contiguous(dispatch_output)
         elif dispatch_output.format.is_deepep_ll():
+            assert deep_gemm_wrapper.ENABLE_JIT_DEEPGEMM and self.use_fp8_w8a8
             return self.forward_deepgemm_masked(dispatch_output)
         else:
-            raise ValueError(
+            raise ValueError(
+                f"Dispatch output format {dispatch_output.format} is not supported"
+            )

     def combine(
         self,
@@ -915,185 +474,6 @@ class DeepEPMoE(EPMoE):
             forward_batch=forward_batch,
         )

-    def _prepare_for_normal(
-        self,
-        hidden_states: torch.Tensor,
-        topk_idx: torch.Tensor,
-    ):
-        from sglang.srt.layers.moe.ep_moe.kernels import (
-            deepep_permute_triton_kernel,
-            deepep_run_moe_deep_preprocess,
-        )
-
-        if hidden_states.shape[0] == 0:
-            reorder_topk_ids = torch.empty(
-                (0,), device=hidden_states.device, dtype=torch.int64
-            )
-            seg_indptr = torch.zeros(
-                (self.num_experts + 1,),
-                device=hidden_states.device,
-                dtype=torch.int64,
-            )
-            return reorder_topk_ids, seg_indptr, hidden_states
-        else:
-            if _use_aiter:
-                # skip permutation here as aiter fused_moe has fused inside
-                reorder_topk_ids = torch.empty(
-                    (0,), device=hidden_states.device, dtype=torch.int64
-                )
-                seg_indptr = torch.zeros(
-                    (self.num_experts + 1,),
-                    device=hidden_states.device,
-                    dtype=torch.int64,
-                )
-                return reorder_topk_ids, seg_indptr, hidden_states
-
-            reorder_topk_ids, self.src2dst, seg_indptr = deepep_run_moe_deep_preprocess(
-                topk_idx, self.num_experts
-            )
-            num_total_tokens = reorder_topk_ids.numel()
-            gateup_input = torch.empty(
-                (int(num_total_tokens), hidden_states.shape[1]),
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-            # PreReorder
-            deepep_permute_triton_kernel[(hidden_states.shape[0],)](
-                hidden_states,
-                gateup_input,
-                self.src2dst,
-                topk_idx,
-                None,
-                self.router_topk,
-                hidden_states.shape[1],
-                BLOCK_SIZE=512,
-            )
-            return reorder_topk_ids, seg_indptr, gateup_input
-
-    def forward_normal(
-        self,
-        dispatch_output: DeepEPNormalOutput,
-    ):
-        hidden_states, topk_idx = (
-            dispatch_output.hidden_states,
-            dispatch_output.topk_idx,
-        )
-        reorder_topk_ids, seg_indptr, hidden_states = self._prepare_for_normal(
-            hidden_states, topk_idx
-        )
-        hidden_states_dtype = hidden_states.dtype
-        hidden_states_device = hidden_states.device
-
-        assert self.quant_method is not None
-        assert self.activation == "silu"
-        if self.grouped_gemm_runner is None:
-            self.grouped_gemm_runner = GroupedGemmRunner(
-                hidden_states.device, use_flashinfer=False  # TODO: use flashinfer
-            )
-
-        if self.activation_scheme == "dynamic" and not self.use_block_quant:
-            max_value = (
-                torch.max(hidden_states)
-                .repeat(self.num_local_experts)
-                .to(torch.float32)
-            )
-            self.w13_input_scale = max_value / torch.finfo(self.fp8_dtype).max
-        weight_indices_cur_rank = torch.arange(
-            0,
-            self.num_local_experts,
-            device=hidden_states.device,
-            dtype=torch.int64,
-        )
-
-        # GroupGemm-0
-        if hidden_states.shape[0] > 0:
-            gateup_output = self.grouped_gemm_runner(
-                a=hidden_states,
-                b=self.w13_weight,
-                c=None,
-                c_dtype=hidden_states.dtype,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w13_input_scale,
-                scale_b=(
-                    self.w13_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w13_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        else:
-            gateup_output = torch.empty(
-                hidden_states.shape[0],
-                self.w13_weight.shape[1],
-                device=hidden_states.device,
-                dtype=hidden_states.dtype,
-            )
-
-        # Act
-        down_input = torch.empty(
-            gateup_output.shape[0],
-            gateup_output.shape[1] // 2,
-            device=gateup_output.device,
-            dtype=(
-                self.fp8_dtype
-                if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states_dtype
-            ),
-        )
-        if self.w2_input_scale is None and not self.use_block_quant:
-            self.w2_input_scale = torch.ones(
-                self.num_local_experts,
-                dtype=torch.float32,
-                device=hidden_states_device,
-            )
-
-        if self.activation == "silu":
-            silu_and_mul_triton_kernel[(gateup_output.shape[0],)](
-                gateup_output,
-                down_input,
-                gateup_output.shape[1],
-                reorder_topk_ids,
-                self.w2_input_scale,
-                0,
-                self.num_local_experts - 1,
-                BLOCK_SIZE=512,
-            )
-        else:
-            raise ValueError(f"Unsupported activation: {self.activation=}")
-
-        del gateup_output
-
-        # GroupGemm-1
-        down_output = torch.empty(
-            down_input.shape[0],
-            self.w2_weight.shape[1],
-            device=hidden_states_device,
-            dtype=hidden_states_dtype,
-        )
-        if down_input.shape[0] > 0:
-            down_output = self.grouped_gemm_runner(
-                a=down_input,
-                b=self.w2_weight,
-                c=down_output,
-                batch_size=self.num_local_experts,
-                weight_column_major=True,
-                seg_indptr=seg_indptr,
-                weight_indices=weight_indices_cur_rank,
-                use_fp8_w8a8=self.use_fp8_w8a8,
-                scale_a=self.w2_input_scale,
-                scale_b=(
-                    self.w2_weight_scale_inv
-                    if self.use_block_quant
-                    else self.w2_weight_scale
-                ),
-                block_shape=self.block_shape,
-            )
-        return down_output
-
     def forward_aiter(
         self,
         dispatch_output: DeepEPNormalOutput,