sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py CHANGED
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
             self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
         )
 
+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
 
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
-
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
             ]
 
         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
 
-
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-
-
+                q = q.view(original_shape)
+                k = k.view(original_shape)
 
         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
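The new customized_position_embedding_applier hook lets a vision model supply its own position-embedding logic in place of the default apply_rotary_pos_emb branch. A minimal sketch of an applier a model might pass in; the rotate-half RoPE shown here is a generic illustration of the expected signature and return value, not code taken from this diff:

from typing import Any, Tuple

import torch


def rotate_half(x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=-1)
    return torch.cat((-x2, x1), dim=-1)


def custom_rope_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,  # assumed here to be a (cos, sin) pair
    x_shape: Any,              # the [bsz, seq, hidden] shape captured above
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Hypothetical applier: apply plain rotary embeddings and return q, k with
    # the same shapes they came in with, as VisionAttention.forward expects.
    cos, sin = position_embeddings
    q = (q * cos) + (rotate_half(q) * sin)
    k = (k * cos) + (rotate_half(k) * sin)
    return q, k


# attn = VisionAttention(..., customized_position_embedding_applier=custom_rope_applier)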
sglang/srt/layers/communicator.py CHANGED
@@ -108,7 +108,7 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if global_server_args_dict["
+                if not global_server_args_dict["moe_a2a_backend"].is_standard()
                 else ScatterMode.FULL
             )
         else:
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-
-            if
-            hidden_states
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
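The comment in the hunk states the intent: when attn_tp_size == 1, each rank can normalize only its local slice before dp_gather_partial instead of normalizing the full gathered buffer afterwards. A rough back-of-the-envelope illustration, with both sizes picked only for the example:

# Rows of hidden_states that a single rank passes through layernorm.
attn_dp_size = 8        # data-parallel attention groups (assumed)
local_tokens = 512      # tokens resident on this rank (assumed)

# Previous order: gather across DP ranks first, then layernorm the result.
rows_normalized_after_gather = attn_dp_size * local_tokens   # 4096

# New order (only when attn_tp_size == 1): layernorm locally, then gather.
rows_normalized_before_gather = local_tokens                 # 512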
sglang/srt/layers/linear.py CHANGED
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
sglang/srt/layers/logits_processor.py CHANGED
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)
 
-
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
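The effect of next_token_logits_buffer is to reuse a caller-provided float32 buffer for the vocab-sliced logits instead of allocating a new tensor via .float() on every forward. A standalone sketch of the two paths, with shapes chosen only for illustration:

import torch

batch_size, padded_vocab, vocab_size = 4, 32064, 32000
logits = torch.randn(batch_size, padded_vocab, dtype=torch.bfloat16)

# Old path: slicing + casting allocates a fresh float32 tensor each step.
out_old = logits[:, :vocab_size].float()

# New path: copy into a preallocated float32 buffer (e.g. one owned by the
# forward batch) so the output tensor's storage stays stable across steps.
next_token_logits_buffer = torch.empty(batch_size, vocab_size, dtype=torch.float)
next_token_logits_buffer.copy_(logits[:, :vocab_size])
out_new = next_token_logits_buffer

assert torch.equal(out_old, out_new)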
sglang/srt/layers/moe/ep_moe/layer.py CHANGED
@@ -1,59 +1,43 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.distributed import
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    gelu_and_mul_triton_kernel,
-    grouped_gemm_triton,
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
-    pre_reorder_triton_kernel,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
-    run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
-    silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.layers.quantization.base_config import
-
-
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
-    sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import
-    DeepEPMode,
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    is_hip,
-    is_npu,
-    next_power_of_2,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.
+    from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -63,10 +47,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,26 +57,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)
 
 
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
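The removed local _get_tile_tokens_dim is replaced by get_tile_tokens_dim imported from sglang.srt.layers.quantization.fp8 (see the new import block above and the call site further down). Restating the removed heuristic as a self-contained sketch, under the assumption that the shared helper behaves the same way:

def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Guess tokens per expert assuming a perfectly even routing, round up to
    # the next power of two, and clamp to the 8-64 range the kernel supports.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    tile = 1 << max(0, num_tokens_per_expert - 1).bit_length()
    return min(max(tile, 8), 64)


# Example: 256 tokens routed with top_k=8 over 128 experts -> 16 tokens per
# expert on average, already a power of two, so the tile dim is 16.
assert tile_tokens_dim(256, 8, 128) == 16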
@@ -132,7 +96,6 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            enable_ep_moe=True,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -317,6 +280,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
+        if self.routed_scaling_factor is not None:
+            output *= self.routed_scaling_factor
         return output
 
 
@@ -341,7 +306,7 @@ class DeepEPMoE(EPMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -361,7 +326,6 @@ class DeepEPMoE(EPMoE):
 
         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
        from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 
         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -731,10 +695,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe =
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +711,9 @@ class FlashInferEPMoE(EPMoE):
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-
-
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -765,7 +730,7 @@ class FlashInferEPMoE(EPMoE):
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -774,14 +739,10 @@ class FlashInferEPMoE(EPMoE):
 
 
 def get_moe_impl_class():
-    if global_server_args_dict["
+    if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if
-
-
-    if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+    if get_moe_expert_parallel_world_size() > 1:
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json ADDED
@@ -0,0 +1,146 @@
+{
+    "1":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2},
+    "2":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
+    "4":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
+    "8":    {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
+    "16":   {"BLOCK_SIZE_M": 32,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1,  "num_warps": 8, "num_stages": 2},
+    "24":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
+    "32":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 4},
+    "48":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 3},
+    "64":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 32,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 8, "num_stages": 2},
+    "96":   {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 3},
+    "128":  {"BLOCK_SIZE_M": 16,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 1,  "num_warps": 4, "num_stages": 2},
+    "256":  {"BLOCK_SIZE_M": 32,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 2},
+    "512":  {"BLOCK_SIZE_M": 32,  "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 64, "num_warps": 4, "num_stages": 2},
+    "1024": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
+    "1536": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 64,  "BLOCK_SIZE_K": 256, "GROUP_SIZE_M": 16, "num_warps": 4, "num_stages": 2},
+    "2048": {"BLOCK_SIZE_M": 64,  "BLOCK_SIZE_N": 128, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 32, "num_warps": 4, "num_stages": 3},
+    "3072": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 128, "GROUP_SIZE_M": 1,  "num_warps": 8, "num_stages": 3},
+    "4096": {"BLOCK_SIZE_M": 128, "BLOCK_SIZE_N": 256, "BLOCK_SIZE_K": 64,  "GROUP_SIZE_M": 32, "num_warps": 8, "num_stages": 3}
+}
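The keys of this tuning table follow the usual fused-MoE Triton convention of mapping a token count M to kernel launch parameters for the E=128, N=352, fp8_w8a8 shape on that GPU. A plausible lookup is to pick the entry whose key is nearest to the runtime M; the exact selection logic lives in fused_moe_triton and is not part of this diff, so treat the sketch below as an assumption:

import json


def select_config(configs: dict, m: int) -> dict:
    # Assumed nearest-key lookup over a table keyed by token count M.
    best_key = min(configs, key=lambda k: abs(int(k) - m))
    return configs[best_key]


with open(
    "E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json"
) as f:
    configs = json.load(f)

# For m=200 the nearest key is "256", so that entry's block sizes are used.
print(select_config(configs, m=200))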