sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/vision.py
@@ -4,7 +4,7 @@ import dataclasses
 import functools
 import math
 from functools import lru_cache, partial
-from typing import Any, Optional, Tuple, Union
+from typing import Any, Callable, Optional, Tuple, Union
 
 import torch
 import torch.nn as nn
@@ -308,6 +308,7 @@ class VisionFlash3Attention(nn.Module):
         cu_seqlens = cu_seqlens.to(dtype=torch.int32).to(q.device)
         seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
         max_seqlen = seq_lens.max().item()
+
         output = flash_attn_varlen_func(
             q,
             k,
@@ -358,6 +359,9 @@ class VisionAttention(nn.Module):
         qkv_bias: bool = True,
         qk_normalization: bool = False,
         layer_norm_eps: float = 1e-06,
+        customized_position_embedding_applier: Callable[
+            [torch.Tensor, torch.Tensor, Any, Any], Tuple[torch.Tensor, torch.Tensor]
+        ] = None,
         **kwargs,
     ):
         super().__init__()
@@ -392,6 +396,7 @@ class VisionAttention(nn.Module):
                 self.dummy_dim, eps=layer_norm_eps, var_hidden_size=embed_dim
             )
 
+        # priority: server_args > passed qkv_backend > sdpa
         if global_server_args_dict["mm_attention_backend"] is None:
            if qkv_backend is None:
                qkv_backend = "sdpa"
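The new comment documents the resolution order for the multimodal attention backend: the server-level setting overrides whatever backend the model passed in, and "sdpa" is the final fallback. A minimal sketch of that resolution, where `server_arg` and `passed_backend` are hypothetical stand-ins for `global_server_args_dict["mm_attention_backend"]` and the constructor's `qkv_backend` argument (the backend names other than "sdpa" are purely illustrative):

```python
from typing import Optional


def resolve_mm_attention_backend(
    server_arg: Optional[str], passed_backend: Optional[str]
) -> str:
    # Server-level override wins, then the backend passed by the model,
    # then the default "sdpa" fallback.
    if server_arg is not None:
        return server_arg
    if passed_backend is not None:
        return passed_backend
    return "sdpa"


assert resolve_mm_attention_backend(None, None) == "sdpa"
assert resolve_mm_attention_backend(None, "some_backend") == "some_backend"
assert resolve_mm_attention_backend("server_backend", "some_backend") == "server_backend"
```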
@@ -401,6 +406,9 @@ class VisionAttention(nn.Module):
 
         print_info_once(f"Using {qkv_backend} as multimodal attention backend.")
 
+        self.customized_position_embedding_applier = (
+            customized_position_embedding_applier
+        )
         self.qkv_backend = QKV_BACKEND_IMPL[qkv_backend](
             head_dim=self.head_size,
             num_heads=self.num_attention_heads_per_partition,
@@ -473,13 +481,13 @@ class VisionAttention(nn.Module):
         if x.dim() == 2:
             x = x.unsqueeze(0)
         assert x.dim() == 3, x.shape
-        bsz, s, _ = x.shape
+        x_shape = x.shape
+        bsz, s, _ = x_shape
         head = self.num_attention_heads_per_partition
         kv_head = self.num_attention_kv_heads_per_partition
         if self.use_qkv_parallel:
             # [b, s, embed_dim] --> [b, s, embed_dim]
             qkv, _ = self.qkv_proj(x)
-
             q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
 
             # [b, s, embed_dim] --> [b * s, head, head_size]
@@ -508,16 +516,25 @@ class VisionAttention(nn.Module):
             ]
 
         if position_embeddings is not None:
-            cos, sin = position_embeddings
             original_shape = q.shape
-            # [total_tokens, head, head_size]
-            q = q.view(-1, head, self.head_size)
-            k = k.view(-1, head, self.head_size)
 
-            q, k = apply_rotary_pos_emb(q, k, cos, sin)
+            if self.customized_position_embedding_applier is not None:
+                q, k = self.customized_position_embedding_applier(
+                    q, k, position_embeddings, x_shape
+                )
+                q = q.view(original_shape)
+                k = k.view(original_shape)
+            else:
+                cos, sin = position_embeddings
+
+                # [total_tokens, head, head_size]
+                q = q.view(-1, head, self.head_size)
+                k = k.view(-1, head, self.head_size)
+
+                q, k = apply_rotary_pos_emb(q, k, cos, sin)
 
-            q = q.view(original_shape)
-            k = k.view(original_shape)
+                q = q.view(original_shape)
+                k = k.view(original_shape)
 
         if q.dim() == 4:
             # [b, s, head, head_size] --> [b * s, head, head_size]
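The hunk above lets a model supply its own position-embedding logic instead of the built-in `apply_rotary_pos_emb` path. Only the call shape `applier(q, k, position_embeddings, x_shape) -> (q, k)` comes from the diff; the sketch below is a hypothetical applier (a standard rotate-half rotary application, assuming `position_embeddings` is a broadcastable `(cos, sin)` pair), not the way any particular sglang model implements it:

```python
from typing import Any, Tuple

import torch


def my_pos_emb_applier(
    q: torch.Tensor,
    k: torch.Tensor,
    position_embeddings: Any,  # assumed here to be a (cos, sin) pair
    x_shape: Any,              # the [bsz, seq, embed_dim] shape of the attention input
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Illustrative only: rotate-half rotary embedding with broadcastable cos/sin."""

    def rotate_half(t: torch.Tensor) -> torch.Tensor:
        t1, t2 = t.chunk(2, dim=-1)
        return torch.cat((-t2, t1), dim=-1)

    cos, sin = position_embeddings
    q = q * cos + rotate_half(q) * sin
    k = k * cos + rotate_half(k) * sin
    # VisionAttention reshapes the returned tensors back to `original_shape` itself.
    return q, k


# Hypothetical usage:
# attn = VisionAttention(..., customized_position_embedding_applier=my_pos_emb_applier)
```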
sglang/srt/layers/communicator.py
@@ -108,7 +108,7 @@ class LayerScatterModes:
         if context.is_layer_sparse:
             return (
                 ScatterMode.SCATTERED
-                if global_server_args_dict["enable_deepep_moe"]
+                if not global_server_args_dict["moe_a2a_backend"].is_standard()
                 else ScatterMode.FULL
             )
         else:
@@ -404,14 +404,24 @@ class CommunicateWithAllReduceAndLayerNormFn:
         if context.attn_dp_size != 1:
             if context.attn_tp_rank == 0:
                 hidden_states += residual
+
+            # Perform layernorm on smaller data before comm. Only valid when attn_tp_size is 1 (tp_size == dp_size)
+            use_layer_norm_before_gather = context.attn_tp_size == 1
+            if use_layer_norm_before_gather:
+                residual.copy_(hidden_states)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
+
             hidden_states, local_hidden_states = (
                 forward_batch.gathered_buffer,
                 hidden_states,
             )
             dp_gather_partial(hidden_states, local_hidden_states, forward_batch)
-            dp_scatter(residual, hidden_states, forward_batch)
-            if hidden_states.shape[0] != 0:
-                hidden_states = layernorm(hidden_states)
+
+            if not use_layer_norm_before_gather:
+                dp_scatter(residual, hidden_states, forward_batch)
+                if hidden_states.shape[0] != 0:
+                    hidden_states = layernorm(hidden_states)
         else:
             # According to the discussion in https://github.com/flashinfer-ai/flashinfer/issues/1223#issuecomment-3047256465
             # We set the max token num to 128 for allreduce fusion with min-latency case(use_oneshot=True).
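The reordering above works because layer norm is applied per token, so normalizing each data-parallel rank's smaller local shard before the gather is equivalent to gathering, scattering back, and then normalizing. A minimal self-contained sketch of that equivalence, using `torch.cat` as a stand-in for `dp_gather_partial` and a plain `layer_norm` as a stand-in for the real norm module:

```python
import torch


def toy_layernorm(x: torch.Tensor) -> torch.Tensor:
    # Stand-in for the RMS/LayerNorm used on the real code path (per-token norm).
    return torch.nn.functional.layer_norm(x, x.shape[-1:])


def old_order(all_locals: list) -> torch.Tensor:
    # gather first, then normalize the (larger) gathered tensor
    gathered = torch.cat(all_locals, dim=0)
    return toy_layernorm(gathered)


def new_order(all_locals: list) -> torch.Tensor:
    # normalize each rank's smaller local shard first, then gather
    return torch.cat([toy_layernorm(t) for t in all_locals], dim=0)


locals_ = [torch.randn(3, 8), torch.randn(2, 8)]
assert torch.allclose(old_order(locals_), new_order(locals_), atol=1e-6)
```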
sglang/srt/layers/linear.py
@@ -13,10 +13,14 @@ from sglang.srt.distributed import (
     divide,
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    parallel_state,
     split_tensor_along_last_dim,
     tensor_model_parallel_all_gather,
     tensor_model_parallel_all_reduce,
 )
+from sglang.srt.distributed.device_communicators.pynccl_allocator import (
+    use_symmetric_memory,
+)
 from sglang.srt.layers.parameter import (
     BasevLLMParameter,
     BlockQuantScaleParameter,
@@ -1292,7 +1296,9 @@ class RowParallelLinear(LinearBase):
         # Only fuse bias add into GEMM for rank 0 (this ensures that
         # bias will not get added more than once in TP>1 case)
         bias_ = None if (self.tp_rank > 0 or self.skip_bias_add) else self.bias
-        output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+        with use_symmetric_memory(parallel_state.get_tp_group()) as sm:
+            output_parallel = self.quant_method.apply(self, input_parallel, bias=bias_)
+            sm.tag(output_parallel)
         if self.reduce_results and self.tp_size > 1 and not can_fuse_mlp_allreduce:
             output = tensor_model_parallel_all_reduce(output_parallel)
         else:
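The GEMM is now run inside `use_symmetric_memory(...)` and its output is tagged, so that when the new pynccl symmetric-memory allocator is enabled the tensor feeding `tensor_model_parallel_all_reduce` lives in NCCL-registered memory. The sketch below only illustrates the context-manager-plus-tag pattern; `ToySymmetricMemoryContext` and `toy_use_symmetric_memory` are hypothetical stand-ins, not sglang's `pynccl_allocator` implementation:

```python
import contextlib

import torch


class ToySymmetricMemoryContext:
    """Hypothetical stand-in for the object yielded by use_symmetric_memory()."""

    def __init__(self, enabled: bool):
        self.enabled = enabled
        self.tagged = []

    def tag(self, tensor: torch.Tensor) -> None:
        # In the real allocator this marks `tensor` as backed by symmetric
        # (NCCL-registered) memory so the following all-reduce can use it.
        if self.enabled:
            self.tagged.append(tensor)


@contextlib.contextmanager
def toy_use_symmetric_memory(enabled: bool = False):
    # Real code: allocations made inside this scope can come from the
    # symmetric-memory pool when the feature is turned on for the TP group.
    yield ToySymmetricMemoryContext(enabled)


# Usage mirrors the RowParallelLinear change:
with toy_use_symmetric_memory(enabled=False) as sm:
    output_parallel = torch.zeros(4, 4)  # stand-in for quant_method.apply(...)
    sm.tag(output_parallel)
```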
sglang/srt/layers/logits_processor.py
@@ -83,6 +83,7 @@ class LogitsProcessorOutput:
 class LogitsMetadata:
     forward_mode: ForwardMode
     capture_hidden_mode: CaptureHiddenMode = CaptureHiddenMode.NULL
+    next_token_logits_buffer: Optional[torch.Tensor] = None
 
     extend_return_logprob: bool = False
     extend_return_top_logprob: bool = False
@@ -148,6 +149,7 @@ class LogitsMetadata:
         return cls(
             forward_mode=forward_batch.forward_mode,
             capture_hidden_mode=forward_batch.capture_hidden_mode,
+            next_token_logits_buffer=forward_batch.next_token_logits_buffer,
             extend_return_logprob=extend_return_logprob,
             extend_return_top_logprob=extend_return_top_logprob,
             extend_token_ids_logprob=extend_token_ids_logprob,
@@ -508,7 +510,13 @@ class LogitsProcessor(nn.Module):
             )
             dp_scatter(logits, global_logits, logits_metadata)
 
-        logits = logits[:, : self.config.vocab_size].float()
+        if logits_metadata.next_token_logits_buffer is not None:
+            logits_buffer = logits_metadata.next_token_logits_buffer
+            assert logits_buffer.dtype == torch.float
+            logits_buffer.copy_(logits[:, : self.config.vocab_size])
+            logits = logits_buffer
+        else:
+            logits = logits[:, : self.config.vocab_size].float()
 
         if self.final_logit_softcapping:
             fused_softcap(logits, self.final_logit_softcapping)
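Instead of allocating a fresh float32 tensor every step with `.float()`, the processor can now copy the vocab slice into a preallocated `next_token_logits_buffer` supplied by the forward batch, presumably so the output storage stays stable across steps (for example under CUDA graph replay, though the diff itself does not state the motivation). A minimal standalone illustration of the same copy-into-buffer pattern:

```python
import torch

vocab_size = 8
padded_vocab = 10

logits = torch.randn(2, padded_vocab, dtype=torch.bfloat16)

# Without a buffer: a fresh float32 allocation on every call.
out_alloc = logits[:, :vocab_size].float()

# With a preallocated float32 buffer: write in place and reuse the same storage.
buffer = torch.empty(2, vocab_size, dtype=torch.float)
buffer.copy_(logits[:, :vocab_size])

assert torch.allclose(out_alloc, buffer)
```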
sglang/srt/layers/moe/ep_moe/layer.py
@@ -1,59 +1,43 @@
 from __future__ import annotations
 
 import logging
-from typing import TYPE_CHECKING, List, Optional, Tuple
+from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.distributed import (
-    get_tensor_model_parallel_rank,
-    get_tensor_model_parallel_world_size,
-)
-from sglang.srt.eplb.expert_location import get_global_expert_location_metadata
+from sglang.srt.distributed.parallel_state import get_moe_expert_parallel_world_size
 from sglang.srt.layers.moe.ep_moe.kernels import (
     ep_gather,
     ep_scatter,
-    gelu_and_mul_triton_kernel,
-    grouped_gemm_triton,
     moe_ep_deepgemm_preprocess,
     post_reorder_triton_kernel,
-    pre_reorder_triton_kernel,
-    pre_reorder_triton_kernel_for_cutlass_moe,
-    run_cutlass_moe_ep_preproess,
-    run_moe_ep_preproess,
     silu_and_mul_masked_post_quant_fwd,
-    silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
+from sglang.srt.layers.moe.utils import DeepEPMode
 from sglang.srt.layers.quantization import deep_gemm_wrapper
-from sglang.srt.layers.quantization.base_config import (
-    QuantizationConfig,
-    QuantizeMethodBase,
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
-    sglang_per_token_quant_fp8,
 )
-from sglang.srt.layers.quantization.unquant import UnquantizedFusedMoEMethod
-from sglang.srt.layers.quantization.w4afp8 import W4AFp8Config, W4AFp8MoEMethod
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import (
-    DeepEPMode,
-    ceil_div,
-    dispose_tensor,
-    get_bool_env_var,
-    is_hip,
-    is_npu,
-    next_power_of_2,
-)
+from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.moe.ep_moe.token_dispatcher import (
+    from sglang.srt.layers.moe.token_dispatcher import (
         DeepEPLLOutput,
         DeepEPNormalOutput,
         DispatchOutput,
@@ -63,10 +47,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-use_flashinfer_trtllm_moe = (
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,26 +57,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)
 
 
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
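The `_get_tile_tokens_dim` helper removed here is replaced by `get_tile_tokens_dim` imported from `sglang.srt.layers.quantization.fp8` (see the import hunk above). For clarity, a worked example of the computation described by the removed code, with a local stand-in for `next_power_of_2`:

```python
def next_power_of_2(n: int) -> int:
    # Local stand-in for sglang.srt.utils.next_power_of_2 (positive inputs).
    return 1 if n <= 1 else 1 << (n - 1).bit_length()


def tile_tokens_dim(num_tokens: int, top_k: int, num_experts: int) -> int:
    # Mirrors the removed helper: estimate tokens per expert assuming a perfect
    # distribution, round up to a power of two, clamp to the 8-64 range the
    # kernel supports.
    num_tokens_per_expert = (num_tokens * top_k) // num_experts
    return min(max(next_power_of_2(num_tokens_per_expert), 8), 64)


assert tile_tokens_dim(512, top_k=8, num_experts=256) == 16   # 16 tokens/expert -> 16
assert tile_tokens_dim(7, top_k=8, num_experts=256) == 8      # tiny batch -> clamp up to 8
assert tile_tokens_dim(4096, top_k=8, num_experts=128) == 64  # 256 tokens/expert -> clamp down to 64
```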
@@ -132,7 +96,6 @@ class EPMoE(FusedMoE):
             activation=activation,
             # apply_router_weight_on_input=apply_router_weight_on_input,
             routed_scaling_factor=routed_scaling_factor,
-            enable_ep_moe=True,
         )
 
         self.start_expert_id = self.moe_ep_rank * self.num_local_experts
@@ -317,6 +280,8 @@ class EPMoE(FusedMoE):
             m_max * self.start_expert_id,
             BLOCK_SIZE=512,
         )
+        if self.routed_scaling_factor is not None:
+            output *= self.routed_scaling_factor
         return output
 
 
@@ -341,7 +306,7 @@ class DeepEPMoE(EPMoE):
         prefix: str = "",
         activation: str = "silu",
         routed_scaling_factor: Optional[float] = None,
-        deepep_mode: DeepEPMode = DeepEPMode.auto,
+        deepep_mode: DeepEPMode = DeepEPMode.AUTO,
     ):
         super().__init__(
             num_experts=num_experts,
@@ -361,7 +326,6 @@ class DeepEPMoE(EPMoE):
 
         # TODO: move to the beginning of the file
         from sglang.srt.distributed.parallel_state import get_tp_group
-        from sglang.srt.managers.schedule_batch import global_server_args_dict
         from sglang.srt.two_batch_overlap import MaybeTboDeepEPDispatcher
 
         self.deepep_dispatcher = MaybeTboDeepEPDispatcher(
@@ -731,10 +695,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +711,9 @@ class FlashInferEPMoE(EPMoE):
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        assert fi_fused_moe is not None
-        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -765,7 +730,7 @@ class FlashInferEPMoE(EPMoE):
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -774,14 +739,10 @@
 
 
 def get_moe_impl_class():
-    if global_server_args_dict["enable_deepep_moe"]:
+    if global_server_args_dict["moe_a2a_backend"].is_deepep():
         return DeepEPMoE
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
         return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
-    if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
+    if get_moe_expert_parallel_world_size() > 1:
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
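The dispatch no longer keys on the old `enable_deepep_moe`/`enable_ep_moe`/`enable_flashinfer_trtllm_moe` flags; it now branches on the MoE all-to-all backend, the MoE expert-parallel world size, and `should_use_flashinfer_trtllm_moe()`. A toy mirror of the new branch order from the hunk above, returning class names as strings so it runs standalone:

```python
def pick_moe_impl(
    a2a_backend_is_deepep: bool,
    enable_flashinfer_cutlass_moe: bool,
    moe_ep_world_size: int,
    use_flashinfer_trtllm_moe: bool,
) -> str:
    """Condensed view of the new get_moe_impl_class() decision tree."""
    if a2a_backend_is_deepep:
        return "DeepEPMoE"
    if enable_flashinfer_cutlass_moe:
        return "FusedMoE"
    if moe_ep_world_size > 1:
        return "FlashInferEPMoE" if use_flashinfer_trtllm_moe else "EPMoE"
    return "FlashInferFusedMoE" if use_flashinfer_trtllm_moe else "FusedMoE"


assert pick_moe_impl(False, False, 8, False) == "EPMoE"
assert pick_moe_impl(False, False, 1, True) == "FlashInferFusedMoE"
assert pick_moe_impl(True, False, 8, True) == "DeepEPMoE"
```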
sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json
@@ -0,0 +1,146 @@
+{
+  "1": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "2": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "4": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "8": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "16": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 8,
+    "num_stages": 2
+  },
+  "24": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "32": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 4
+  },
+  "48": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 8,
+    "num_stages": 3
+  },
+  "64": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 32,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 8,
+    "num_stages": 2
+  },
+  "96": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "128": {
+    "BLOCK_SIZE_M": 16,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "256": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "512": {
+    "BLOCK_SIZE_M": 32,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 64,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "1024": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "1536": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 64,
+    "BLOCK_SIZE_K": 256,
+    "GROUP_SIZE_M": 16,
+    "num_warps": 4,
+    "num_stages": 2
+  },
+  "2048": {
+    "BLOCK_SIZE_M": 64,
+    "BLOCK_SIZE_N": 128,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 4,
+    "num_stages": 3
+  },
+  "3072": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 128,
+    "GROUP_SIZE_M": 1,
+    "num_warps": 8,
+    "num_stages": 3
+  },
+  "4096": {
+    "BLOCK_SIZE_M": 128,
+    "BLOCK_SIZE_N": 256,
+    "BLOCK_SIZE_K": 64,
+    "GROUP_SIZE_M": 32,
+    "num_warps": 8,
+    "num_stages": 3
+  }
+}
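This new file is a tuned Triton launch-parameter table for the fp8 fused-MoE kernel on the named GPU: keys are token counts (the M dimension of the grouped GEMM) and values are block sizes, warp counts, and pipeline stages. The sketch below shows one plausible way such a table is consumed; the nearest-key lookup is a common fused-MoE convention, not a confirmed description of sglang's `fused_moe_triton` selection logic, and the path in the comment is illustrative:

```python
import json
from typing import Dict


def load_moe_config(path: str) -> Dict[int, Dict[str, int]]:
    # Keys are stored as strings ("1", "2", ..., "4096"); convert to int token counts.
    with open(path) as f:
        raw = json.load(f)
    return {int(num_tokens): cfg for num_tokens, cfg in raw.items()}


def pick_config(configs: Dict[int, Dict[str, int]], num_tokens: int) -> Dict[str, int]:
    # Use the tuned entry whose token count is closest to the actual batch size.
    nearest = min(configs, key=lambda k: abs(k - num_tokens))
    return configs[nearest]


# Hypothetical usage:
# configs = load_moe_config(".../E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json")
# pick_config(configs, num_tokens=300)  # -> the "256" entry's BLOCK_SIZE_* / num_warps / num_stages
```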