sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +11 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +5 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +8 -4
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +144 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -21
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/model_executor/cuda_graph_runner.py +17 -3
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +5 -0
- sglang/srt/models/deepseek_v2.py +23 -17
- sglang/srt/models/glm4_moe.py +82 -19
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +80 -20
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +3 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
- /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/cutlass_mla_backend.py
CHANGED
@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
             max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             self.cuda_graph_kv_indices,
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         self.forward_metadata = CutlassMLADecodeMetadata(
             self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
         else:
             super().init_forward_metadata_replay_cuda_graph(
sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 self.decode_cuda_graph_metadata[bs] = metadata
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.draft_decode_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens.max().item()
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-                req_pool_indices, :
+                :bs, :
             ]
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
                     : (bs + 1)
                 ]

-                metadata.page_table = self.target_verify_metadata["page_table"][
-                    req_pool_indices, :
-                ]
+                metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]

                 self.target_verify_metadata[bs] = metadata
             else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.target_verify_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                 : (bs + 1)
             ]
-            metadata.page_table = self.draft_extend_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]

             self.draft_extend_metadata[bs] = metadata

@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
             ][: (encoder_bs + 1)]

             metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
-                req_pool_indices, :
+                :bs, :
             ]

             self.forward_metadata = metadata
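Note on the FlashAttentionBackend hunks above: every change replaces gathering page-table rows by `req_pool_indices` with slicing the first `bs` rows of the preallocated table when building CUDA-graph metadata. A minimal PyTorch sketch of the two indexing forms (sizes invented for illustration; this is my reading of the diff, not sglang code):

```python
import torch

# Invented sizes for illustration only.
max_bs, max_pages = 8, 64
page_table = torch.zeros(max_bs, max_pages, dtype=torch.int32)

bs = 3
req_pool_indices = torch.tensor([5, 1, 7])

gathered = page_table[req_pool_indices, :]  # advanced indexing: copies the selected rows
sliced = page_table[:bs, :]                 # basic slicing: a view of the first bs rows

assert gathered.shape == sliced.shape == (bs, max_pages)
```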
sglang/srt/layers/attention/trtllm_mla_backend.py
CHANGED
@@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_blocks,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

         return block_kv_indices
@@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

         metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
@@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             metadata.block_kv_indices,
             self.req_to_token.stride(0),
             metadata.block_kv_indices.shape[1],
-            TRITON_PAD_NUM_PAGE_PER_BLOCK,
-            self.page_size,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
         )

     def get_cuda_graph_seq_len_fill_value(self) -> int:
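The CutlassMLA and TRTLLM MLA hunks above make the same edit: the trailing arguments of the KV-indices Triton kernel are now passed by keyword (`PAGED_SIZE=...`, `NUM_PAGE_PER_BLOCK=...`) instead of positionally. A self-contained sketch of that calling style; the kernel below is illustrative only and is not the sglang kernel:

```python
import torch
import triton
import triton.language as tl


@triton.jit
def fill_page_ids_kernel(
    out_ptr, n, NUM_PAGE_PER_BLOCK: tl.constexpr, PAGED_SIZE: tl.constexpr
):
    # Each program writes NUM_PAGE_PER_BLOCK entries of token_index // PAGED_SIZE.
    pid = tl.program_id(0)
    offs = pid * NUM_PAGE_PER_BLOCK + tl.arange(0, NUM_PAGE_PER_BLOCK)
    mask = offs < n
    tl.store(out_ptr + offs, offs // PAGED_SIZE, mask=mask)


def fill_page_ids(n: int = 256, paged_size: int = 32) -> torch.Tensor:
    out = torch.empty(n, dtype=torch.int32, device="cuda")
    grid = (triton.cdiv(n, 64),)
    # Constexpr parameters passed by name, mirroring the style adopted above.
    fill_page_ids_kernel[grid](out, n, NUM_PAGE_PER_BLOCK=64, PAGED_SIZE=paged_size)
    return out
```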
sglang/srt/layers/attention/vision.py
CHANGED
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union

 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )

+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
             if qkv_backend is None:
                 qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):

         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")

+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
-
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)

             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
         ]

         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)

-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)

-            q = q.view(original_shape)
-            k = k.view(original_shape)
+                q = q.view(original_shape)
+                k = k.view(original_shape)

         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
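The VisionAttention changes above add an optional `customized_position_embedding_applier` hook. Based only on the type annotation and the call site in the diff, the callback receives `(q, k, position_embeddings, x_shape)` and returns the transformed `(q, k)`; when it is not provided, the existing cos/sin rotary path is used. A sketch of a conforming callback (the identity applier is illustrative, not part of sglang):

```python
from typing import Any, Tuple

import torch


def identity_position_embedding_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,
    x_shape: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # A real applier would use position_embeddings (e.g. precomputed cos/sin tables)
    # and x_shape (the [batch, seq, hidden] shape of the input) to rotate q and k.
    return q, k


# Per the diff, VisionAttention stores the callable and, when it is set, calls
#   q, k = self.customized_position_embedding_applier(q, k, position_embeddings, x_shape)
# and then reshapes q and k back to their original shape.
```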
sglang/srt/layers/communicator.py
CHANGED
@@ -108,7 +108,7 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if global_server_args_dict["enable_deepep_moe"]
+                if not global_server_args_dict["moe_a2a_backend"].is_standard()
                 else ScatterMode.FULL
             )
         else:
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-            dp_scatter(residual, hidden_states, forward_batch)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
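The comment added in the hunk above states the intent: when `attn_tp_size == 1`, each rank's pre-gather tensor holds only its own DP shard, so running layernorm before `dp_gather_partial` normalizes fewer rows than doing it after the gather. A toy shape-level illustration with invented numbers:

```python
import torch

dp_size, local_tokens, hidden = 4, 128, 4096
layernorm = torch.nn.LayerNorm(hidden)

local = torch.randn(local_tokens, hidden)                # this rank's shard before gather
gathered = torch.randn(dp_size * local_tokens, hidden)   # the buffer after dp_gather

rows_before = layernorm(local).shape[0]     # 128 rows normalized per rank
rows_after = layernorm(gathered).shape[0]   # 512 rows normalized per rank
print(rows_before, rows_after)
```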
sglang/srt/layers/linear.py
CHANGED
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
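The linear.py hunk wraps the partial GEMM in the new `use_symmetric_memory` context manager and tags its output, presumably so the subsequent tensor-parallel all-reduce can operate on symmetric NCCL memory. A usage sketch that mirrors the diff; it assumes an initialized sglang tensor-parallel group, and the plain `nn.Linear` stands in for `self.quant_method.apply`:

```python
import torch

from sglang.srt.distributed import parallel_state, tensor_model_parallel_all_reduce
from sglang.srt.distributed.device_communicators.pynccl_allocator import (
    use_symmetric_memory,
)


def row_parallel_forward(linear: torch.nn.Linear, x: torch.Tensor) -> torch.Tensor:
    # Allocate the GEMM output while the symmetric-memory allocator is active and
    # tag it, as in the RowParallelLinear change above, before the all-reduce.
    with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
        output_parallel = linear(x)
        sm.tag(output_parallel)
    return tensor_model_parallel_all_reduce(output_parallel)
```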
sglang/srt/layers/logits_processor.py
CHANGED
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None

     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)

-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()

         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
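The logits_processor.py change lets callers supply a preallocated float32 buffer (`next_token_logits_buffer`) that the final vocab-size slice is copied into, instead of materializing a new tensor with `.float()` on every forward; my reading is that this allows one output buffer to be reused across steps. A self-contained PyTorch illustration of the two paths:

```python
import torch

bs, padded_vocab, vocab_size = 4, 1056, 1024
logits = torch.randn(bs, padded_vocab, dtype=torch.bfloat16)

# Old path: allocate a fresh float32 tensor each call.
fresh = logits[:, :vocab_size].float()

# New path: copy into a caller-provided float32 buffer, reused across calls.
next_token_logits_buffer = torch.empty(bs, vocab_size, dtype=torch.float)
assert next_token_logits_buffer.dtype == torch.float
next_token_logits_buffer.copy_(logits[:, :vocab_size])

torch.testing.assert_close(fresh, next_token_logits_buffer)
```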
sglang/srt/layers/moe/ep_moe/layer.py
CHANGED
@@ -1,28 +1,17 @@
 from __future__ import annotations

 import logging
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, Optional

 import torch

-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    gelu_and_mul_triton_kernel,
-    grouped_gemm_triton,
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
-    pre_reorder_triton_kernel,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
-    run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
-    silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton.layer import (
@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
     should_use_flashinfer_trtllm_moe,
 )
 from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
-)
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.quantization.fp8 import (
     Fp8Config,
     Fp8MoEMethod,
@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
-    sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    DeepEPMode,
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    is_hip,
-    is_npu,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu

 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+    from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            enable_ep_moe=True,
         )

         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -304,6 +280,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
+        if self.routed_scaling_factor is not None:
+            output *= self.routed_scaling_factor
         return output


@@ -328,7 +306,7 @@ class DeepEPMoE(EPMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -348,7 +326,6 @@ class DeepEPMoE(EPMoE):

         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
         from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher

         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -762,11 +739,10 @@ class FlashInferEPMoE(EPMoE):


 def get_moe_impl_class():
-    if global_server_args_dict["enable_deepep_moe"]:
+    if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if global_server_args_dict["enable_ep_moe"]:
+    if get_moe_expert_parallel_world_size() > 1:
         return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
     return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
ADDED
@@ -0,0 +1,146 @@
+{
+    "1": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "4": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "8": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "16": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "24": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "32": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 4
+    },
+    "48": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "64": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 32,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 8,
+        "num_stages": 2
+    },
+    "96": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "128": {
+        "BLOCK_SIZE_M": 16,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "256": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "512": {
+        "BLOCK_SIZE_M": 32,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 64,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1024": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "1536": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 64,
+        "BLOCK_SIZE_K": 256,
+        "GROUP_SIZE_M": 16,
+        "num_warps": 4,
+        "num_stages": 2
+    },
+    "2048": {
+        "BLOCK_SIZE_M": 64,
+        "BLOCK_SIZE_N": 128,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 4,
+        "num_stages": 3
+    },
+    "3072": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 128,
+        "GROUP_SIZE_M": 1,
+        "num_warps": 8,
+        "num_stages": 3
+    },
+    "4096": {
+        "BLOCK_SIZE_M": 128,
+        "BLOCK_SIZE_N": 256,
+        "BLOCK_SIZE_K": 64,
+        "GROUP_SIZE_M": 32,
+        "num_warps": 8,
+        "num_stages": 3
+    }
+}
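The new JSON file above is a fused-MoE Triton tuning table: each top-level key is a token count M and each value is a set of launch parameters (block sizes, group size, warps, stages). A sketch of how such a table can be consumed; the nearest-key lookup is an assumption for illustration, not sglang's exact selection code:

```python
import json

# Path of the config file added above (assumed to be in the working directory).
CONFIG_PATH = "E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json"

with open(CONFIG_PATH) as f:
    configs = {int(k): v for k, v in json.load(f).items()}


def pick_config(m: int) -> dict:
    # Choose the tuned entry whose token-count key is closest to the runtime M.
    best_key = min(configs, key=lambda k: abs(k - m))
    return configs[best_key]


print(pick_config(200))  # resolves to the "256" entry in this table
```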
sglang/srt/layers/moe/fused_moe_triton/layer.py
CHANGED
@@ -14,10 +14,12 @@ from sglang.srt.distributed import (
     get_moe_expert_parallel_world_size,
     get_moe_tensor_parallel_rank,
     get_moe_tensor_parallel_world_size,
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
 from sglang.srt.layers.moe.topk import StandardTopKOutput
 from sglang.srt.layers.quantization.base_config import (
@@ -94,7 +96,6 @@ class FusedMoE(torch.nn.Module):
         no_combine: bool = False,
         routed_scaling_factor: Optional[float] = None,
         enable_flashinfer_cutlass_moe: Optional[bool] = False,
-        enable_ep_moe: Optional[bool] = False,
     ):
         super().__init__()

@@ -112,7 +113,6 @@ class FusedMoE(torch.nn.Module):
         if enable_flashinfer_cutlass_moe and quant_config is None:
             logger.warning("Disable flashinfer MoE when quantization config is None.")
             enable_flashinfer_cutlass_moe = False
-            enable_ep_moe = False

         self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
         self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -121,7 +121,7 @@ class FusedMoE(torch.nn.Module):
         self.moe_tp_rank = get_moe_tensor_parallel_rank()
         assert num_experts % self.moe_ep_size == 0
         self.num_local_experts = num_experts // self.moe_ep_size
-        if enable_ep_moe:
+        if self.moe_ep_size > 1:
             # TODO(ch-wan): support shared experts fusion
             # Create a tensor of size num_experts filled with -1
             self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
@@ -630,24 +630,27 @@ class FusedMoE(torch.nn.Module):
         )

         # Matrix multiply.
-        final_hidden_states = self.quant_method.apply(
-            layer=self,
-            x=hidden_states,
-            topk_output=topk_output,
-            activation=self.activation,
-            apply_router_weight_on_input=self.apply_router_weight_on_input,
-            routed_scaling_factor=self.routed_scaling_factor,
-            **(
-                dict(
-                    tp_rank=self.moe_tp_rank,
-                    tp_size=self.moe_tp_size,
-                    ep_rank=self.moe_ep_rank,
-                    ep_size=self.moe_ep_size,
-                )
-                if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
-                else {}
-            ),
-        )
+        with use_symmetric_memory(get_tp_group()) as sm:
+            final_hidden_states = self.quant_method.apply(
+                layer=self,
+                x=hidden_states,
+                topk_output=topk_output,
+                activation=self.activation,
+                apply_router_weight_on_input=self.apply_router_weight_on_input,
+                routed_scaling_factor=self.routed_scaling_factor,
+                **(
+                    dict(
+                        tp_rank=self.moe_tp_rank,
+                        tp_size=self.moe_tp_size,
+                        ep_rank=self.moe_ep_rank,
+                        ep_size=self.moe_ep_size,
+                    )
+                    if self.quant_method.__class__.__name__
+                    == "ModelOptNvFp4FusedMoEMethod"
+                    else {}
+                ),
+            )
+            sm.tag(final_hidden_states)

         if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)