sglang 0.4.10.post1__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (80)
  1. sglang/compile_deep_gemm.py +8 -1
  2. sglang/global_config.py +5 -1
  3. sglang/srt/conversation.py +0 -112
  4. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  5. sglang/srt/disaggregation/prefill.py +1 -0
  6. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  7. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  8. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  9. sglang/srt/distributed/parallel_state.py +11 -0
  10. sglang/srt/entrypoints/engine.py +4 -2
  11. sglang/srt/entrypoints/http_server.py +35 -15
  12. sglang/srt/eplb/expert_distribution.py +4 -2
  13. sglang/srt/hf_transformers_utils.py +25 -10
  14. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  15. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  16. sglang/srt/layers/attention/trtllm_mla_backend.py +6 -6
  17. sglang/srt/layers/attention/vision.py +27 -10
  18. sglang/srt/layers/communicator.py +14 -4
  19. sglang/srt/layers/linear.py +7 -1
  20. sglang/srt/layers/logits_processor.py +9 -1
  21. sglang/srt/layers/moe/ep_moe/layer.py +11 -35
  22. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/layer.py +26 -23
  24. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  25. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  26. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  27. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  28. sglang/srt/layers/moe/utils.py +43 -0
  29. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  30. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  31. sglang/srt/layers/quantization/fp8.py +5 -1
  32. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  33. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  34. sglang/srt/lora/lora_registry.py +7 -0
  35. sglang/srt/managers/cache_controller.py +8 -4
  36. sglang/srt/managers/data_parallel_controller.py +52 -2
  37. sglang/srt/managers/io_struct.py +6 -1
  38. sglang/srt/managers/schedule_batch.py +3 -2
  39. sglang/srt/managers/schedule_policy.py +3 -1
  40. sglang/srt/managers/scheduler.py +144 -6
  41. sglang/srt/managers/template_manager.py +25 -22
  42. sglang/srt/managers/tokenizer_manager.py +114 -62
  43. sglang/srt/managers/utils.py +45 -1
  44. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  45. sglang/srt/mem_cache/hicache_storage.py +13 -21
  46. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  47. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  48. sglang/srt/model_executor/cuda_graph_runner.py +17 -3
  49. sglang/srt/model_executor/forward_batch_info.py +13 -3
  50. sglang/srt/model_executor/model_runner.py +5 -0
  51. sglang/srt/models/deepseek_v2.py +23 -17
  52. sglang/srt/models/glm4_moe.py +82 -19
  53. sglang/srt/models/grok.py +3 -3
  54. sglang/srt/models/llama4.py +13 -2
  55. sglang/srt/models/mixtral.py +3 -3
  56. sglang/srt/models/mllama4.py +428 -19
  57. sglang/srt/models/qwen2_moe.py +1 -4
  58. sglang/srt/models/qwen3_moe.py +7 -8
  59. sglang/srt/models/step3_vl.py +1 -1
  60. sglang/srt/multimodal/processors/base_processor.py +4 -3
  61. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  62. sglang/srt/operations_strategy.py +1 -1
  63. sglang/srt/server_args.py +80 -20
  64. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  65. sglang/srt/two_batch_overlap.py +6 -4
  66. sglang/srt/utils.py +3 -24
  67. sglang/srt/weight_sync/utils.py +1 -1
  68. sglang/test/runners.py +2 -2
  69. sglang/test/test_utils.py +3 -3
  70. sglang/version.py +1 -1
  71. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  72. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +80 -74
  73. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  74. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  75. /sglang/srt/mem_cache/{nixl → storage/nixl}/hicache_nixl.py +0 -0
  76. /sglang/srt/mem_cache/{nixl → storage/nixl}/nixl_utils.py +0 -0
  77. /sglang/srt/mem_cache/{nixl → storage/nixl}/test_hicache_nixl_storage.py +0 -0
  78. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  79. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  80. {sglang-0.4.10.post1.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  block_kv_indices,
  self.req_to_token.stride(0),
  max_seqlen_pad,
- PAGE_SIZE,
+ PAGED_SIZE=PAGE_SIZE,
  )
  workspace_size = cutlass_mla_get_workspace_size(
  max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
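
Note: this hunk (and the matching ones below) switches the trailing launch arguments of the Triton kernel from positional to keyword form. The sketch below is a minimal, hypothetical illustration of why that is robust: tl.constexpr parameters of a @triton.jit kernel can be passed by keyword at the launch site, so the call no longer depends on the parameter order of the kernel. The kernel and names here are illustrative, not the sglang kernel.

import torch
import triton
import triton.language as tl

@triton.jit
def copy_blocks_kernel(src_ptr, dst_ptr, n_elems, PAGED_SIZE: tl.constexpr):
    # Each program copies one block of PAGED_SIZE elements.
    pid = tl.program_id(0)
    offs = pid * PAGED_SIZE + tl.arange(0, PAGED_SIZE)
    mask = offs < n_elems
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

def copy_blocks(src: torch.Tensor, dst: torch.Tensor, paged_size: int = 64):
    grid = (triton.cdiv(src.numel(), paged_size),)
    # Passing the constexpr by keyword keeps the launch site unambiguous.
    copy_blocks_kernel[grid](src, dst, src.numel(), PAGED_SIZE=paged_size)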
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  self.cuda_graph_kv_indices,
  self.req_to_token.stride(0),
  self.cuda_graph_kv_indices.stride(0),
- PAGE_SIZE,
+ PAGED_SIZE=PAGE_SIZE,
  )
  self.forward_metadata = CutlassMLADecodeMetadata(
  self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
  self.cuda_graph_kv_indices,
  self.req_to_token.stride(0),
  self.cuda_graph_kv_indices.stride(0),
- PAGE_SIZE,
+ PAGED_SIZE=PAGE_SIZE,
  )
  else:
  super().init_forward_metadata_replay_cuda_graph(
@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
  )
  metadata.page_table = self.decode_cuda_graph_metadata[
  "page_table_draft_decode"
- ][req_pool_indices, :]
+ ][:bs, :]
  self.decode_cuda_graph_metadata[bs] = metadata
  else:
  # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: bs + 1]
  metadata.page_table = self.draft_decode_metadata_topk_normal[
  "page_table"
- ][req_pool_indices, :]
+ ][:bs, :]
 
  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.max_seq_len_k = seq_lens.max().item()
  # Precompute page table
  metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
- req_pool_indices, :
+ :bs, :
  ]
  # Precompute cumulative sequence lengths
  metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
  : (bs + 1)
  ]
 
- metadata.page_table = self.target_verify_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
 
  self.target_verify_metadata[bs] = metadata
  else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: bs + 1]
  metadata.page_table = self.target_verify_metadata_topk_normal[
  "page_table"
- ][req_pool_indices, :]
+ ][:bs, :]
 
  # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
  metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
  metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
  : (bs + 1)
  ]
- metadata.page_table = self.draft_extend_metadata["page_table"][
- req_pool_indices, :
- ]
+ metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
 
  self.draft_extend_metadata[bs] = metadata
 
@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
  ][: (encoder_bs + 1)]
 
  metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
- req_pool_indices, :
+ :bs, :
  ]
 
  self.forward_metadata = metadata
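
All of the page_table changes in this backend replace gathering rows with req_pool_indices by taking the first bs rows of the preallocated CUDA-graph buffer. The diff itself does not state the motivation; a small stand-alone illustration of the difference between the two indexing forms (plain torch, not the sglang code):

import torch

# Buffer preallocated for the maximum CUDA-graph batch size.
page_table = torch.arange(8 * 16, dtype=torch.int32).reshape(8, 16)
req_pool_indices = torch.tensor([3, 0, 5])
bs = req_pool_indices.numel()

gathered = page_table[req_pool_indices, :]  # advanced indexing: materializes a copy
sliced = page_table[:bs, :]                 # basic slice: a view of the first bs rows

assert gathered.shape == sliced.shape == (bs, 16)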
@@ -147,8 +147,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  block_kv_indices,
  self.req_to_token.stride(0),
  max_blocks,
- TRITON_PAD_NUM_PAGE_PER_BLOCK,
- self.page_size,
+ NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+ PAGED_SIZE=self.page_size,
  )
 
  return block_kv_indices
@@ -204,8 +204,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  block_kv_indices,
  self.req_to_token.stride(0),
  max_seqlen_pad,
- TRITON_PAD_NUM_PAGE_PER_BLOCK,
- self.page_size,
+ NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+ PAGED_SIZE=self.page_size,
  )
 
  metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
@@ -248,8 +248,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
  metadata.block_kv_indices,
  self.req_to_token.stride(0),
  metadata.block_kv_indices.shape[1],
- TRITON_PAD_NUM_PAGE_PER_BLOCK,
- self.page_size,
+ NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+ PAGED_SIZE=self.page_size,
  )
 
  def get_cuda_graph_seq_len_fill_value(self) -> int:
@@ -4,7 +4,7 @@ import dataclasses
  import functools
  import math
  from functools import lru_cache, partial
- from typing import Any, Optional, Tuple, Union
+ from typing import Any, Callable, Optional, Tuple, Union
 
  import torch
  import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
  cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
  seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
  max_seqlen = seq_lens.max().item()
+
  output = flash_attn_varlen_func(
  q,
  k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
  qkv_bias: bool = True,
  qk_normalization: bool = False,
  layer_norm_eps: float = 1e-06,
+ customized_position_embedding_applier: Callable[
+ [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+ ] = None,
  **kwargs,
  ):
  super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
  self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
  )
 
+ # priority: server_args > passed qkv_backend > sdpa
  if global_server_args_dict["mm_attention_backend"] is None:
  if qkv_backend is None:
  qkv_backend = "sdpa"
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
 
  print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
+ self.customized_position_embedding_applier = (
+ customized_position_embedding_applier
+ )
  self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
  head_dim=self.head_size,
  num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
  if x.dim() == 2:
  x = x.unsqueeze(0)
  assert x.dim() == 3, x.shape
- bsz, s, _ = x.shape
+ x_shape = x.shape
+ bsz, s, _ = x_shape
  head = self.num_attention_heads_per_partition
  kv_head = self.num_attention_kv_heads_per_partition
  if self.use_qkv_parallel:
  # [b, s, embed_dim] --> [b, s, embed_dim]
  qkv, _ = self.qkv_proj(x)
-
  q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
  # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
  ]
 
  if position_embeddings is not None:
- cos, sin = position_embeddings
  original_shape = q.shape
- # [total_tokens, head, head_size]
- q = q.view(-1, head, self.head_size)
- k = k.view(-1, head, self.head_size)
 
- q, k = apply_rotary_pos_emb(q, k, cos, sin)
+ if self.customized_position_embedding_applier is not None:
+ q, k = self.customized_position_embedding_applier(
+ q, k, position_embeddings, x_shape
+ )
+ q = q.view(original_shape)
+ k = k.view(original_shape)
+ else:
+ cos, sin = position_embeddings
+
+ # [total_tokens, head, head_size]
+ q = q.view(-1, head, self.head_size)
+ k = k.view(-1, head, self.head_size)
+
+ q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
- q = q.view(original_shape)
- k = k.view(original_shape)
+ q = q.view(original_shape)
+ k = k.view(original_shape)
 
  if q.dim() == 4:
  # [b, s, head, head_size] --> [b * s, head, head_size]
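
The new customized_position_embedding_applier hook is called with (q, k, position_embeddings, x_shape) and must return the transformed (q, k). A hypothetical applier with that call shape (the name and pass-through body are illustrative only, not a model shipped in this release):

from typing import Any, Tuple
import torch

def example_position_embedding_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,
    x_shape: Any,
) -> Tuple[torch.Tensor, torch.Tensor]:
    # A model-specific rotary scheme would go here; this pass-through only
    # demonstrates the contract expected by VisionAttention.
    cos, sin = position_embeddings
    return q, k

# Hypothetical usage:
# attn = VisionAttention(..., customized_position_embedding_applier=example_position_embedding_applier)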
@@ -108,7 +108,7 @@ class LayerScatterModes:
  if context.is_layer_sparse:
  return (
  ScatterMode.SCATTERED
- if global_server_args_dict["enable_deepep_moe"]
+ if not global_server_args_dict["moe_a2a_backend"].is_standard()
  else ScatterMode.FULL
  )
  else:
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
  if context.attn_dp_size != 1:
  if context.attn_tp_rank == 0:
  hidden_states += residual
+
+ # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+ use_layer_norm_before_gather = context.attn_tp_size == 1
+ if use_layer_norm_before_gather:
+ residual.copy_(hidden_states)
+ if hidden_states.shape[0] != 0:
+ hidden_states = layernorm(hidden_states)
+
  hidden_states, local_hidden_states = (
  forward_batch.gathered_buffer,
  hidden_states,
  )
  dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
- dp_scatter(residual, hidden_states, forward_batch)
- if hidden_states.shape[0] != 0:
- hidden_states = layernorm(hidden_states)
+
+ if not use_layer_norm_before_gather:
+ dp_scatter(residual, hidden_states, forward_batch)
+ if hidden_states.shape[0] != 0:
+ hidden_states = layernorm(hidden_states)
  else:
  # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
  # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
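
The reordering above normalizes the smaller per-rank shard before dp_gather_partial when attn_tp_size == 1. This works because token-wise normalization commutes with gathering along the token dimension; a minimal check of that property (plain torch, unrelated to the sglang kernels):

import torch

norm = torch.nn.LayerNorm(8)
shard_a, shard_b = torch.randn(3, 8), torch.randn(5, 8)

# Normalizing after gathering ...
after_gather = norm(torch.cat([shard_a, shard_b], dim=0))
# ... matches gathering the normalized shards, since LayerNorm acts per token.
before_gather = torch.cat([norm(shard_a), norm(shard_b)], dim=0)

torch.testing.assert_close(after_gather, before_gather)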
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
  divide,
  get_tensor_model_parallel_rank,
  get_tensor_model_parallel_world_size,
+ parallel_state,
  split_tensor_along_last_dim,
  tensor_model_parallel_all_gather,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+ )
  from sglang.srt.layers.parameter import (
  BasevLLMParameter,
  BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase):
  # Only fuse bias add into GEMM for rank 0 (this ensures that
  # bias will not get added more than once in TP>1 case)
  bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
- output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+ with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+ output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+ sm.tag(output_parallel)
  if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
  output = tensor_model_parallel_all_reduce(output_parallel)
  else:
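
Here the GEMM output is allocated inside the use_symmetric_memory(...) context and tagged via sm.tag(...) before the tensor-parallel all-reduce, matching the new pynccl_allocator module added in this release. A rough sketch of what such a context-manager pattern generally looks like (an illustrative stand-in, not the sglang implementation):

from contextlib import contextmanager
from typing import Any, Iterator, List

class _SymmetricMemoryScope:
    """Collects tensors that should live in a communicator-registered pool."""

    def __init__(self) -> None:
        self.tagged: List[Any] = []

    def tag(self, tensor: Any) -> None:
        # A real allocator would record pool/registration metadata here so a
        # following collective can use NCCL window-registered ("symmetric") memory.
        self.tagged.append(tensor)

@contextmanager
def use_symmetric_memory_sketch(group: Any) -> Iterator[_SymmetricMemoryScope]:
    scope = _SymmetricMemoryScope()
    # A real implementation would swap in a registered memory pool for `group` here ...
    try:
        yield scope
    finally:
        pass  # ... and restore the default allocator here.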
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
  class LogitsMetadata:
  forward_mode: ForwardMode
  capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+ next_token_logits_buffer: Optional[torch.Tensor] = None
 
  extend_return_logprob: bool = False
  extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
  return cls(
  forward_mode=forward_batch.forward_mode,
  capture_hidden_mode=forward_batch.capture_hidden_mode,
+ next_token_logits_buffer=forward_batch.next_token_logits_buffer,
  extend_return_logprob=extend_return_logprob,
  extend_return_top_logprob=extend_return_top_logprob,
  extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
  )
  dp_scatter(logits, global_logits, logits_metadata)
 
- logits = logits[:, : self.config.vocab_size].float()
+ if logits_metadata.next_token_logits_buffer is not None:
+ logits_buffer = logits_metadata.next_token_logits_buffer
+ assert logits_buffer.dtype == torch.float
+ logits_buffer.copy_(logits[:, : self.config.vocab_size])
+ logits = logits_buffer
+ else:
+ logits = logits[:, : self.config.vocab_size].float()
 
  if self.final_logit_softcapping:
  fused_softcap(logits, self.final_logit_softcapping)
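
The new branch writes the vocab-sliced logits into a preallocated float32 buffer (next_token_logits_buffer) instead of allocating a fresh tensor with .float(). A small stand-alone illustration of the same copy-with-cast pattern (not the sglang code):

import torch

vocab_size = 4
logits = torch.randn(2, 6, dtype=torch.bfloat16)           # padded vocab dimension
buffer = torch.empty(2, vocab_size, dtype=torch.float32)   # reused across forward passes

# copy_() slices and up-casts in place, avoiding the allocation .float() would make.
buffer.copy_(logits[:, :vocab_size])

torch.testing.assert_close(buffer, logits[:, :vocab_size].float())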
@@ -1,28 +1,17 @@
  from __future__ import annotations
 
  import logging
- from typing import TYPE_CHECKING, List, Optional, Tuple
+ from typing import TYPE_CHECKING, Optional
 
  import torch
 
- from sglang.srt.distributed import (
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
- )
- from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+ from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
  from sglang.srt.layers.moe.ep_moe.kernels import (
  ep_gather,
  ep_scatter,
- gelu_and_mul_triton_kernel,
- grouped_gemm_triton,
  moe_ep_deepgemm_preprocess,
  post_reorder_triton_kernel,
- pre_reorder_triton_kernel,
- pre_reorder_triton_kernel_for_cutlass_moe,
- run_cutlass_moe_ep_preproess,
- run_moe_ep_preproess,
  silu_and_mul_masked_post_quant_fwd,
- silu_and_mul_triton_kernel,
  tma_align_input_scale,
  )
  from sglang.srt.layers.moe.fused_moe_triton.layer import (
@@ -31,11 +20,9 @@ from sglang.srt.layers.moe.fused_moe_triton.layer import (
  should_use_flashinfer_trtllm_moe,
  )
  from sglang.srt.layers.moe.topk import TopKOutput
+ from sglang.srt.layers.moe.utils import DeepEPMode
  from sglang.srt.layers.quantization import deep_gemm_wrapper
- from sglang.srt.layers.quantization.base_config import (
- QuantizationConfig,
- QuantizeMethodBase,
- )
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.quantization.fp8 import (
  Fp8Config,
  Fp8MoEMethod,
@@ -44,23 +31,13 @@ from sglang.srt.layers.quantization.fp8 import (
  from sglang.srt.layers.quantization.fp8_kernel import (
  is_fp8_fnuz,
  sglang_per_token_group_quant_fp8,
- sglang_per_token_quant_fp8,
  )
- from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
- from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
  from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
- from sglang.srt.utils import (
- DeepEPMode,
- ceil_div,
- dispose_tensor,
- get_bool_env_var,
- is_hip,
- is_npu,
- )
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
  if TYPE_CHECKING:
- from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+ from sglang.srt.layers.moe.token_dispatcher import (
  DeepEPLLOutput,
  DeepEPNormalOutput,
  DispatchOutput,
@@ -119,7 +96,6 @@ class EPMoE(FusedMoE):
  activation=activation,
  # apply_router_weight_on_input=apply_router_weight_on_input,
  routed_scaling_factor=routed_scaling_factor,
- enable_ep_moe=True,
  )
 
  self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -304,6 +280,8 @@ class EPMoE(FusedMoE):
  m_max * self.start_expert_id,
  BLOCK_SIZE=512,
  )
+ if self.routed_scaling_factor is not None:
+ output *= self.routed_scaling_factor
  return output
 
 
@@ -328,7 +306,7 @@ class DeepEPMoE(EPMoE):
  prefix: str = "",
  activation: str = "silu",
  routed_scaling_factor: Optional[float] = None,
- deepep_mode: DeepEPMode = DeepEPMode.auto,
+ deepep_mode: DeepEPMode = DeepEPMode.AUTO,
  ):
  super().__init__(
  num_experts=num_experts,
@@ -348,7 +326,6 @@ class DeepEPMoE(EPMoE):
 
  # TODO: move to the beginning of the file
  from sglang.srt.distributed.parallel_state import get_tp_group
- from sglang.srt.managers.schedule_batch import global_server_args_dict
  from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 
  self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -762,11 +739,10 @@ class FlashInferEPMoE(EPMoE):
 
 
  def get_moe_impl_class():
- if global_server_args_dict["enable_deepep_moe"]:
+ if global_server_args_dict["moe_a2a_backend"].is_deepep():
  return DeepEPMoE
  if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
- # Must come before EPMoE because FusedMoE also supports enable_ep_moe
  return FusedMoE
- if global_server_args_dict["enable_ep_moe"]:
+ if get_moe_expert_parallel_world_size() > 1:
  return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
  return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
@@ -0,0 +1,146 @@
+ {
+   "1": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "2": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "4": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 32,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "8": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "16": {
+     "BLOCK_SIZE_M": 32,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 8,
+     "num_stages": 2
+   },
+   "24": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "32": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 32,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 4
+   },
+   "48": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 32,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 8,
+     "num_stages": 3
+   },
+   "64": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 32,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 8,
+     "num_stages": 2
+   },
+   "96": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "128": {
+     "BLOCK_SIZE_M": 16,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "256": {
+     "BLOCK_SIZE_M": 32,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "512": {
+     "BLOCK_SIZE_M": 32,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 64,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "1024": {
+     "BLOCK_SIZE_M": 64,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "1536": {
+     "BLOCK_SIZE_M": 128,
+     "BLOCK_SIZE_N": 64,
+     "BLOCK_SIZE_K": 256,
+     "GROUP_SIZE_M": 16,
+     "num_warps": 4,
+     "num_stages": 2
+   },
+   "2048": {
+     "BLOCK_SIZE_M": 64,
+     "BLOCK_SIZE_N": 128,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 4,
+     "num_stages": 3
+   },
+   "3072": {
+     "BLOCK_SIZE_M": 128,
+     "BLOCK_SIZE_N": 256,
+     "BLOCK_SIZE_K": 128,
+     "GROUP_SIZE_M": 1,
+     "num_warps": 8,
+     "num_stages": 3
+   },
+   "4096": {
+     "BLOCK_SIZE_M": 128,
+     "BLOCK_SIZE_N": 256,
+     "BLOCK_SIZE_K": 64,
+     "GROUP_SIZE_M": 32,
+     "num_warps": 8,
+     "num_stages": 3
+   }
+ }
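
This new file follows the usual fused-MoE tuning-config layout: each top-level key is a token-count bucket mapping to a Triton tile configuration (block sizes, group size, warps, stages). The lookup sketch below assumes the common nearest-key selection used by fused-MoE Triton code; whether sglang 0.4.10.post2 resolves keys exactly this way is an assumption here.

import json
from typing import Dict

def pick_config(configs: Dict[str, dict], num_tokens: int) -> dict:
    # Assumed behavior: choose the bucket whose key is numerically closest
    # to the number of tokens being processed.
    nearest = min((int(k) for k in configs), key=lambda k: abs(k - num_tokens))
    return configs[str(nearest)]

# Example with the table above: 1000 tokens falls closest to the "1024" bucket.
# with open("E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json") as f:
#     cfg = pick_config(json.load(f), 1000)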
@@ -14,10 +14,12 @@ from sglang.srt.distributed import (
  get_moe_expert_parallel_world_size,
  get_moe_tensor_parallel_rank,
  get_moe_tensor_parallel_world_size,
- get_tensor_model_parallel_rank,
- get_tensor_model_parallel_world_size,
+ get_tp_group,
  tensor_model_parallel_all_reduce,
  )
+ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+ use_symmetric_memory,
+ )
  from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
  from sglang.srt.layers.moe.topk import StandardTopKOutput
  from sglang.srt.layers.quantization.base_config import (
@@ -94,7 +96,6 @@ class FusedMoE(torch.nn.Module):
  no_combine: bool = False,
  routed_scaling_factor: Optional[float] = None,
  enable_flashinfer_cutlass_moe: Optional[bool] = False,
- enable_ep_moe: Optional[bool] = False,
  ):
  super().__init__()
 
@@ -112,7 +113,6 @@ class FusedMoE(torch.nn.Module):
  if enable_flashinfer_cutlass_moe and quant_config is None:
  logger.warning("Disable flashinfer MoE when quantization config is None.")
  enable_flashinfer_cutlass_moe = False
- enable_ep_moe = False
 
  self.enable_flashinfer_cutlass_moe = enable_flashinfer_cutlass_moe
  self.moe_ep_size = get_moe_expert_parallel_world_size()
@@ -121,7 +121,7 @@ class FusedMoE(torch.nn.Module):
  self.moe_tp_rank = get_moe_tensor_parallel_rank()
  assert num_experts % self.moe_ep_size == 0
  self.num_local_experts = num_experts // self.moe_ep_size
- if enable_ep_moe:
+ if self.moe_ep_size > 1:
  # TODO(ch-wan): support shared experts fusion
  # Create a tensor of size num_experts filled with -1
  self.expert_map_cpu = torch.full((self.num_experts,), -1, dtype=torch.int32)
@@ -630,24 +630,27 @@ class FusedMoE(torch.nn.Module):
  )
 
  # Matrix multiply.
- final_hidden_states = self.quant_method.apply(
- layer=self,
- x=hidden_states,
- topk_output=topk_output,
- activation=self.activation,
- apply_router_weight_on_input=self.apply_router_weight_on_input,
- routed_scaling_factor=self.routed_scaling_factor,
- **(
- dict(
- tp_rank=self.moe_tp_rank,
- tp_size=self.moe_tp_size,
- ep_rank=self.moe_ep_rank,
- ep_size=self.moe_ep_size,
- )
- if self.quant_method.__class__.__name__ == "ModelOptNvFp4FusedMoEMethod"
- else {}
- ),
- )
+ with use_symmetric_memory(get_tp_group()) as sm:
+ final_hidden_states = self.quant_method.apply(
+ layer=self,
+ x=hidden_states,
+ topk_output=topk_output,
+ activation=self.activation,
+ apply_router_weight_on_input=self.apply_router_weight_on_input,
+ routed_scaling_factor=self.routed_scaling_factor,
+ **(
+ dict(
+ tp_rank=self.moe_tp_rank,
+ tp_size=self.moe_tp_size,
+ ep_rank=self.moe_ep_rank,
+ ep_size=self.moe_ep_size,
+ )
+ if self.quant_method.__class__.__name__
+ == "ModelOptNvFp4FusedMoEMethod"
+ else {}
+ ),
+ )
+ sm.tag(final_hidden_states)
 
  if self.reduce_results and (self.moe_tp_size > 1 or self.moe_ep_size > 1):
  final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)