sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
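Before digging into the hunks, it can help to confirm which of the two wheels is actually installed; `sglang/version.py` is the one-line bump in this release. A quick, hedged check (this assumes `sglang` re-exports `__version__` from `version.py`, as current releases do):

```python
# Print the installed sglang version; the value comes from sglang/version.py,
# the file bumped by this diff.
import sglang

print(sglang.__version__)  # "0.4.6.post5" after the upgrade, "0.4.6.post3" before
```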
The hunk-level view below covers the expert-parallel MoE changes.

sglang/srt/layers/moe/ep_moe/layer.py:

@@ -5,6 +5,9 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
     from deep_gemm import (
@@ -40,7 +43,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -49,7 +52,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -92,6 +95,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +123,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
 
@@ -136,6 +141,7 @@ class EPMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -159,6 +165,7 @@ class EPMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
 
+        self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
         self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -210,6 +217,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:
@@ -229,6 +240,9 @@ class EPMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
+            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                layer_id=self.layer_id,
+            ),
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -265,25 +279,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
 
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +307,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(
@@ -306,14 +317,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -340,13 +351,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
@@ -365,10 +377,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.
-
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
            down_output,
            output,
            src2dst,
@@ -377,7 +392,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
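A pattern repeated throughout the EPMoE and DeepEPMoE forward hunks: the input's shape, dtype, and device are captured up front, each large intermediate (`hidden_states`, `gateup_input`, `gateup_output`, `down_input`) is released with `dispose_tensor` or `del` as soon as the next buffer exists, and the final output is allocated from the saved metadata. A minimal, self-contained sketch of that lifetime management; the `dispose_tensor` below only approximates the helper in `sglang.srt.utils`, and plain matmuls stand in for the grouped GEMM kernels:

```python
import torch

def dispose_tensor(t: torch.Tensor) -> None:
    # Free the underlying storage while keeping the Python object alive
    # (rough stand-in for the sglang.srt.utils helper used in the diff).
    t.set_(torch.empty((0,), dtype=t.dtype, device=t.device))

def moe_forward_sketch(hidden_states: torch.Tensor, w13: torch.Tensor, w2: torch.Tensor) -> torch.Tensor:
    # 1) Save metadata first; the tensor itself may be freed below.
    shape, dtype, device = hidden_states.shape, hidden_states.dtype, hidden_states.device

    gateup_output = hidden_states @ w13   # stand-in for GroupGemm-0
    dispose_tensor(hidden_states)         # input no longer needed

    down_input = torch.nn.functional.silu(gateup_output)
    del gateup_output                     # free before the next GEMM

    down_output = down_input @ w2         # stand-in for GroupGemm-1
    del down_input

    # 2) Rebuild the output from the saved metadata, not from hidden_states.
    return down_output.to(dtype=dtype, device=device).reshape(shape)

x = torch.randn(4, 16)
y = moe_forward_sketch(x, torch.randn(16, 16), torch.randn(16, 16))
assert y.shape == (4, 16)
```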
@@ -417,6 +432,28 @@ class EPMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+    ) -> None:
+        physical_expert_ids = (
+            get_global_expert_location_metadata().logical_to_all_physical(
+                self.layer_id, expert_id
+            )
+        )
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
     ) -> None:
         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
             return
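The split into `weight_loader` and `_weight_loader_physical` reflects the new expert-location machinery: one logical expert id can map to several physical expert slots (redundant experts placed by the expert-location/EPLB manager), and the same checkpoint tensor is written into every slot this rank owns. A self-contained sketch of that fan-out; the hard-coded table below is only a stand-in for `get_global_expert_location_metadata().logical_to_all_physical()`:

```python
from typing import Dict, List

# Stand-in mapping: most logical experts occupy one physical slot,
# while "hot" experts may be replicated into extra slots.
LOGICAL_TO_PHYSICAL: Dict[int, List[int]] = {
    0: [0],
    1: [1, 64],  # logical expert 1 replicated into physical slot 64
}

def weight_loader(weights: Dict[int, str], logical_expert_id: int, checkpoint_tensor: str) -> None:
    # Fan the same checkpoint tensor out to every physical slot of the logical expert.
    for physical_expert_id in LOGICAL_TO_PHYSICAL[logical_expert_id]:
        weights[physical_expert_id] = checkpoint_tensor

weights: Dict[int, str] = {}
weight_loader(weights, logical_expert_id=1, checkpoint_tensor="experts.1.w13_weight")
assert weights == {1: "experts.1.w13_weight", 64: "experts.1.w13_weight"}
```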
@@ -460,7 +497,8 @@ class EPMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
             if (
-                param_data[expert_id] != 1
+                (shard_id == "w1" or shard_id == "w3")
+                and param_data[expert_id] != 1
                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -534,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-        w13_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_input_scale", w13_input_scale)
-        set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
             ones_tensor,
@@ -549,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_input_scale", w2_input_scale)
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
         w2_weight_scale = torch.nn.Parameter(
             ones_tensor,
             requires_grad=False,
@@ -611,7 +639,7 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
                 self.quant_config.weight_block_size[1],
             )
             # NOTE(HandH1998): To ensure proper alignment of the block-wise quantization scales, the output_size of the weights for both the gate and up layers must be divisible by block_n.
-            # Required by
+            # Required by column parallel or enabling merged weights
             if intermediate_size % block_n != 0:
                 raise ValueError(
                     f"The output_size of gate's and up's weight = "
@@ -802,6 +830,7 @@ class DeepEPMoE(EPMoE):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -821,6 +850,7 @@ class DeepEPMoE(EPMoE):
             top_k,
             hidden_size,
             intermediate_size,
+            layer_id,
             params_dtype,
             renormalize,
             use_grouped_topk,
@@ -881,6 +911,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
@@ -903,18 +936,12 @@ class DeepEPMoE(EPMoE):
             )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +955,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(
@@ -937,14 +971,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -961,12 +995,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
 
+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1043,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty(
-            hidden_states_fp8.shape,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(
@@ -1049,16 +1083,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1104,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1121,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
 
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out
@@ -1107,6 +1151,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(
@@ -1135,6 +1180,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
@@ -1150,3 +1196,11 @@ class DeepEPMoE(EPMoE):
         )
 
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
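The new module-level factory `get_moe_impl_class()` and the now-required `layer_id` argument are how model code (for example the reworked `deepseek_v2.py` and `qwen3_moe.py` in the file list) picks and constructs its expert layer. A minimal, self-contained sketch of the selection logic; the classes are stubs and a plain dict stands in for `global_server_args_dict`:

```python
# Hedged sketch only: FusedMoE/EPMoE/DeepEPMoE are stubs and `server_args`
# replaces global_server_args_dict; the branching mirrors get_moe_impl_class().
class FusedMoE: ...
class EPMoE(FusedMoE): ...
class DeepEPMoE(EPMoE): ...

def get_moe_impl_class(server_args: dict):
    if server_args.get("enable_deepep_moe"):
        return DeepEPMoE
    if server_args.get("enable_ep_moe"):
        return EPMoE
    return FusedMoE

assert get_moe_impl_class({"enable_ep_moe": True}) is EPMoE
assert get_moe_impl_class({}) is FusedMoE
```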
sglang/srt/layers/moe/ep_moe/token_dispatcher.py:

@@ -1,8 +1,15 @@
+import logging
+from dataclasses import dataclass
+
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
-from sglang.srt.utils import DeepEPMode
+from sglang.srt.managers.expert_distribution import (
+    get_global_expert_distribution_recorder,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.utils import DeepEPMode, load_json_config
 
 try:
-    from deep_ep import Buffer
+    from deep_ep import Buffer, Config
 
     from sglang.srt.layers.quantization.fp8_kernel import (
         sglang_per_token_group_quant_fp8,
@@ -12,7 +19,7 @@ try:
 except ImportError:
     use_deepep = False
 
-from enum import IntEnum, auto
+from enum import Enum, IntEnum, auto
 from typing import Optional, Tuple, Union
 
 import torch
@@ -25,6 +32,8 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
 
+logger = logging.getLogger(__name__)
+
 
 class DeepEPDispatchMode(IntEnum):
     NORMAL = auto()
@@ -32,7 +41,6 @@ class DeepEPDispatchMode(IntEnum):
 
 
 class DeepEPBuffer:
-
     _buffer = None
     _dispatch_mode: Optional[DeepEPDispatchMode] = None
     _hidden_size: Optional[int] = None
@@ -60,8 +68,10 @@ class DeepEPBuffer:
         if deepep_mode.enable_normal():
             hidden_bytes = hidden_size * param_bytes
             for config in (
-                Buffer.get_dispatch_config(group.size()),
-                Buffer.get_combine_config(group.size()),
+                DeepEPConfig.get_instance().normal_dispatch_config
+                or Buffer.get_dispatch_config(group.size()),
+                DeepEPConfig.get_instance().normal_combine_config
+                or Buffer.get_combine_config(group.size()),
             ):
                 num_nvl_bytes = max(
                     config.get_nvl_buffer_size_hint(hidden_bytes, group.size()),
@@ -88,7 +98,12 @@ class DeepEPBuffer:
             num_nvl_bytes,
             num_rdma_bytes,
             low_latency_mode=deepep_mode.enable_low_latency(),
-            num_qps_per_rank=(
+            num_qps_per_rank=(
+                max(
+                    num_experts // group.size(),
+                    DeepEPConfig.get_instance().num_sms // 2,
+                )
+            ),
         )
         return cls._buffer
 
@@ -113,6 +128,35 @@ class DeepEPBuffer:
         cls._dispatch_mode = DeepEPDispatchMode.LOW_LATENCY
 
 
+class DeepEPConfig:
+    _instance = None
+
+    def __init__(self):
+        config_str = global_server_args_dict["deepep_config"]
+        if config_str:
+            config_parsed = load_json_config(config_str)
+            if torch.distributed.get_rank() == 0:
+                logger.info(f"Use DeepEP Config: {config_parsed}")
+            config_dispatch = config_parsed["normal_dispatch"]
+            config_combine = config_parsed["normal_combine"]
+
+            self.normal_dispatch_config = Config(**config_dispatch)
+            self.normal_combine_config = Config(**config_combine)
+
+            assert config_dispatch["num_sms"] == config_combine["num_sms"]
+            self.num_sms = config_dispatch["num_sms"]
+        else:
+            self.normal_dispatch_config = None
+            self.normal_combine_config = None
+            self.num_sms = Buffer.num_sms
+
+    @classmethod
+    def get_instance(cls):
+        if cls._instance is None:
+            cls._instance = DeepEPConfig()
+        return cls._instance
+
+
 class _DeepEPDispatcherImplBase:
     def __init__(
         self,
@@ -295,6 +339,14 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             allocate_on_comm_stream=(previous_event is not None) and self.async_finish,
             expert_alignment=128 if _ENABLE_JIT_DEEPGEMM else 1,
+            config=DeepEPConfig.get_instance().normal_dispatch_config,
+        )
+
+        get_global_expert_distribution_recorder().on_deepep_dispatch_normal(
+            num_recv_tokens_per_expert_list,
+            num_tokens_per_rank=num_tokens_per_rank,
+            num_tokens_per_rdma_rank=num_tokens_per_rdma_rank,
+            num_tokens_per_expert=num_tokens_per_expert,
         )
 
         return (
@@ -394,6 +446,7 @@ class _DeepEPDispatcherImplNormal(_DeepEPDispatcherImplBase):
             async_finish=self.async_finish,
             previous_event=previous_event,
             allocate_on_comm_stream=previous_event is not None,
+            config=DeepEPConfig.get_instance().normal_combine_config,
         )
         return combined_x, event
 
@@ -459,6 +512,10 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
        ):
        hook() if self.return_recv_hook else event.current_stream_wait()
 
+        get_global_expert_distribution_recorder().on_deepep_dispatch_low_latency(
+            masked_m
+        )
+
        reorder_topk_ids = seg_indptr = None
 
        return (
@@ -571,6 +628,14 @@ class _DeepEPDispatcherImplLowLatency(_DeepEPDispatcherImplBase):
         )
 
 
+@dataclass
+class _Stage(Enum):
+    INITIAL = auto()
+    AFTER_DISPATCH_A = auto()
+    AFTER_DISPATCH_B = auto()
+    AFTER_COMBINE_A = auto()
+
+
 class DeepEPDispatcher:
     def __init__(
         self,
@@ -609,6 +674,8 @@ class DeepEPDispatcher:
             **common_kwargs,
         )
 
+        self._stage = _Stage.INITIAL
+
     def dispatch(self, *args, **kwargs) -> Tuple:
         self.dispatch_a(*args, **kwargs)
         ret = self.dispatch_b()
@@ -621,6 +688,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode = None,
     ):
+        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)
         inner_state = self._get_impl(forward_mode).dispatch_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -629,6 +697,7 @@ class DeepEPDispatcher:
         self._dispatch_intermediate_state = forward_mode, inner_state
 
     def dispatch_b(self):
+        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)
         forward_mode, inner_state = self._dispatch_intermediate_state
         del self._dispatch_intermediate_state
         return self._get_impl(forward_mode).dispatch_b(*inner_state)
@@ -645,6 +714,7 @@ class DeepEPDispatcher:
         topk_weights: torch.Tensor,
         forward_mode: ForwardMode,
     ):
+        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)
         inner_state = self._get_impl(forward_mode).combine_a(
             hidden_states=hidden_states,
             topk_idx=topk_idx,
@@ -653,6 +723,7 @@ class DeepEPDispatcher:
         self._combine_intermediate_state = forward_mode, inner_state
 
     def combine_b(self):
+        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)
         forward_mode, inner_state = self._combine_intermediate_state
         del self._combine_intermediate_state
         return self._get_impl(forward_mode).combine_b(*inner_state)
@@ -665,3 +736,7 @@ class DeepEPDispatcher:
             return self._low_latency_dispatcher
         else:
             raise ValueError(f"Invalid deepep_mode: {self.deepep_mode}")
+
+    def _update_stage(self, old_stage, new_stage):
+        assert self._stage == old_stage
+        self._stage = new_stage
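The `_Stage` enum added above turns the dispatcher's split-phase API into an explicit state machine: every `dispatch_a`/`dispatch_b`/`combine_a`/`combine_b` call asserts that the previous phase completed, so out-of-order or repeated calls fail fast instead of silently reusing stale intermediate state. A self-contained sketch of just that guard (the bodies are placeholders, not the DeepEP communication code):

```python
from enum import Enum, auto

class _Stage(Enum):
    INITIAL = auto()
    AFTER_DISPATCH_A = auto()
    AFTER_DISPATCH_B = auto()
    AFTER_COMBINE_A = auto()

class Dispatcher:
    def __init__(self):
        self._stage = _Stage.INITIAL

    def _update_stage(self, old, new):
        # Refuse to run a phase unless the expected previous phase just finished.
        assert self._stage == old, f"expected {old}, got {self._stage}"
        self._stage = new

    def dispatch_a(self):
        self._update_stage(_Stage.INITIAL, _Stage.AFTER_DISPATCH_A)

    def dispatch_b(self):
        self._update_stage(_Stage.AFTER_DISPATCH_A, _Stage.AFTER_DISPATCH_B)

    def combine_a(self):
        self._update_stage(_Stage.AFTER_DISPATCH_B, _Stage.AFTER_COMBINE_A)

    def combine_b(self):
        self._update_stage(_Stage.AFTER_COMBINE_A, _Stage.INITIAL)

d = Dispatcher()
d.dispatch_a(); d.dispatch_b(); d.combine_a(); d.combine_b()  # valid order, back to INITIAL
```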
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py:

@@ -994,7 +994,7 @@ def get_default_config(
             "num_stages": 2 if _is_hip else 4,
         }
     else:
-        # Block-wise quant: BLOCK_SIZE_K must be
+        # Block-wise quant: BLOCK_SIZE_K must be divisible by block_shape[1]
         config = {
             "BLOCK_SIZE_M": 64,
             "BLOCK_SIZE_N": block_shape[0],
sglang/srt/layers/moe/fused_moe_triton/layer.py:

@@ -186,6 +186,19 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
 
         if _is_hip and get_bool_env_var("SGLANG_AITER_MOE"):
             assert not no_combine, "unsupported"
+            if apply_router_weight_on_input:
+                assert (
+                    topk_weights.dim() == 2
+                ), "`topk_weights` should be in shape (num_tokens, topk)"
+                _, topk = topk_weights.shape
+                assert (
+                    topk == 1
+                ), "Only support topk=1 when `apply_router_weight_on_input` is True"
+                x = x * topk_weights.to(x.dtype)
+                topk_weights = torch.ones_like(
+                    topk_weights, dtype=torch.float32
+                )  # topk_weights must be FP32 (float32)
+
             return ck_moe_2stages(
                 x,
                 layer.w13_weight,
@@ -270,6 +283,7 @@ class FusedMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: Optional[int] = None,
         params_dtype: Optional[torch.dtype] = None,
         reduce_results: bool = False,
         renormalize: bool = True,
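The `apply_router_weight_on_input` branch added to the AITER path folds the single top-1 routing weight into the activations before the kernel runs and then resets `topk_weights` to ones, so the later combine step does not scale the expert outputs a second time. A small, self-contained PyTorch sketch of that trick, with illustrative shapes:

```python
import torch

def fold_router_weight_on_input(x: torch.Tensor, topk_weights: torch.Tensor):
    # Requires top-k == 1: each token has exactly one routing weight.
    assert topk_weights.dim() == 2 and topk_weights.shape[1] == 1
    x = x * topk_weights.to(x.dtype)                                   # scale tokens once, up front
    topk_weights = torch.ones_like(topk_weights, dtype=torch.float32)  # neutralize the later combine
    return x, topk_weights

tokens = torch.randn(4, 8)
weights = torch.rand(4, 1)
scaled, ones = fold_router_weight_on_input(tokens, weights)
assert torch.allclose(scaled, tokens * weights) and bool((ones == 1.0).all())
```

The `layer_id: Optional[int] = None` added to `FusedMoE.__init__` presumably keeps the three MoE classes call-compatible with the `get_moe_impl_class()` factory shown earlier, since EPMoE and DeepEPMoE now require it.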
|