sglang 0.4.10__py3-none-any.whl → 0.4.10.post1__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in the public registry, and is provided for informational purposes only.
Files changed (35)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/srt/configs/model_config.py +1 -0
  3. sglang/srt/disaggregation/launch_lb.py +5 -20
  4. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  5. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  6. sglang/srt/layers/attention/utils.py +6 -1
  7. sglang/srt/layers/moe/ep_moe/layer.py +19 -34
  8. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -2
  9. sglang/srt/layers/quantization/fp8.py +52 -0
  10. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  11. sglang/srt/managers/cache_controller.py +35 -35
  12. sglang/srt/managers/scheduler.py +1 -0
  13. sglang/srt/mem_cache/hicache_storage.py +15 -6
  14. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  15. sglang/srt/mem_cache/memory_pool.py +15 -118
  16. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  17. sglang/srt/mem_cache/nixl/hicache_nixl.py +163 -0
  18. sglang/srt/mem_cache/nixl/nixl_utils.py +238 -0
  19. sglang/srt/mem_cache/nixl/test_hicache_nixl_storage.py +216 -0
  20. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  21. sglang/srt/model_executor/cuda_graph_runner.py +25 -1
  22. sglang/srt/model_executor/model_runner.py +8 -1
  23. sglang/srt/model_loader/weight_utils.py +2 -0
  24. sglang/srt/models/deepseek_v2.py +5 -6
  25. sglang/srt/models/glm4_moe.py +3 -3
  26. sglang/srt/models/step3_vl.py +0 -3
  27. sglang/srt/server_args.py +40 -6
  28. sglang/srt/utils.py +1 -0
  29. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  30. sglang/version.py +1 -1
  31. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/METADATA +1 -1
  32. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/RECORD +35 -30
  33. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/WHEEL +0 -0
  34. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/licenses/LICENSE +0 -0
  35. {sglang-0.4.10.dist-info → sglang-0.4.10.post1.dist-info}/top_level.txt +0 -0
@@ -418,6 +418,26 @@ if __name__ == "__main__":
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
+
+    # handling ModelScope model downloads
+    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
+        if os.path.exists(args.model_path):
+            print(f"Using local model path: {args.model_path}")
+        else:
+            try:
+                from modelscope import snapshot_download
+
+                print(f"Using ModelScope to download model: {args.model_path}")
+
+                # download the model and replace args.model_path
+                args.model_path = snapshot_download(
+                    args.model_path,
+                )
+                print(f"Model downloaded to: {args.model_path}")
+            except Exception as e:
+                print(f"ModelScope download failed: {str(e)}")
+                raise e
+
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)

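Note on the ModelScope hook above: the fallback only runs when SGLANG_USE_MODELSCOPE is truthy and the given --model-path is not an existing local directory. A minimal sketch of exercising it (the repo id below is only an illustrative example, not taken from the diff):

import os

os.environ["SGLANG_USE_MODELSCOPE"] = "1"  # "true" also works; the check is case-insensitive

model_path = "Qwen/Qwen2.5-0.5B-Instruct"  # hypothetical ModelScope repo id
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
    if not os.path.exists(model_path):
        from modelscope import snapshot_download

        # snapshot_download returns the local cache directory of the repo,
        # which then replaces the original --model-path value.
        model_path = snapshot_download(model_path)
print(model_path)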
@@ -112,6 +112,7 @@ class ModelConfig:
         mm_disabled_models = [
             "Gemma3ForConditionalGeneration",
             "Llama4ForConditionalGeneration",
+            "Step3VLForConditionalGeneration",
         ]
         if self.hf_config.architectures[0] in mm_disabled_models:
             enable_multimodal = False
@@ -1,6 +1,8 @@
 import argparse
 import dataclasses
 
+from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
+
 
 @dataclasses.dataclass
 class LBArgs:
@@ -18,7 +20,7 @@ class LBArgs:
         parser.add_argument(
             "--rust-lb",
             action="store_true",
-            help="Use Rust load balancer",
+            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
         )
         parser.add_argument(
             "--host",
@@ -115,25 +117,8 @@ def main():
     args = parser.parse_args()
     lb_args = LBArgs.from_cli_args(args)
 
-    if lb_args.rust_lb:
-        from sgl_pdlb._rust import LoadBalancer as RustLB
-
-        RustLB(
-            host=lb_args.host,
-            port=lb_args.port,
-            policy=lb_args.policy,
-            prefill_infos=lb_args.prefill_infos,
-            decode_infos=lb_args.decode_infos,
-            log_interval=lb_args.log_interval,
-            timeout=lb_args.timeout,
-        ).start()
-    else:
-        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
-
-        prefill_configs = [
-            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
-        ]
-        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
+    prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
+    run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
 
 
 if __name__ == "__main__":
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
+    get_bool_env_var,
     get_free_port,
     get_int_env_var,
     get_ip,
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
             self.bootstrap_timeout = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
+
+            self.enable_custom_mem_pool = get_bool_env_var(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
             self.session_pool = defaultdict(requests.Session)
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
         socket.connect(endpoint)
         return socket
 
+    def _transfer_data(self, mooncake_session_id, transfer_blocks):
+        if not transfer_blocks:
+            return 0
+
+        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
+        if self.enable_custom_mem_pool:
+            # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
+            for src_addr, dst_addr, length in transfer_blocks:
+                status = self.engine.transfer_sync(
+                    mooncake_session_id, src_addr, dst_addr, length
+                )
+                if status != 0:
+                    return status
+            return 0
+        else:
+            src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
+            )
+
     def send_kvcache(
         self,
         mooncake_session_id: str,
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
 
         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
 
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
         futures = [
             executor.submit(
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        src_addr_list = []
-        dst_addr_list = []
-        length_list = []
+        transfer_blocks = []
         prefill_aux_ptrs = self.kv_args.aux_data_ptrs
         prefill_aux_item_lens = self.kv_args.aux_item_lens
+
         for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
             length = prefill_aux_item_lens[i]
             src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
             dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
-            src_addr_list.append(src_addr)
-            dst_addr_list.append(dst_addr)
-            length_list.append(length)
-        return self.engine.batch_transfer_sync(
-            mooncake_session_id, src_addr_list, dst_addr_list, length_list
-        )
+            transfer_blocks.append((src_addr, dst_addr, length))
+
+        return self._transfer_data(mooncake_session_id, transfer_blocks)
 
 
     def sync_status_to_decode_endpoint(
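Reviewer note on the refactor above: both the KV-cache path and the aux-data path now collect (src_addr, dst_addr, length) tuples and hand them to the shared _transfer_data helper, which either batches them or falls back to per-block transfers when the custom memory pool is enabled. A self-contained sketch of the data reshaping, with the Mooncake engine calls replaced by stubs so it runs anywhere:

transfer_blocks = [(0x1000, 0x2000, 256), (0x1100, 0x2100, 512)]

def transfer_sync_stub(session_id, src, dst, length):
    return 0  # stand-in for engine.transfer_sync; 0 means success

def batch_transfer_sync_stub(session_id, srcs, dsts, lengths):
    return 0  # stand-in for engine.batch_transfer_sync

enable_custom_mem_pool = False
if enable_custom_mem_pool:
    # Per-block path: stop at the first non-zero status, as in the diff.
    status = 0
    for src, dst, length in transfer_blocks:
        status = transfer_sync_stub("session", src, dst, length)
        if status != 0:
            break
else:
    # Batched path: unzip the tuples into three parallel lists.
    src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
    status = batch_transfer_sync_stub(
        "session", list(src_addrs), list(dst_addrs), list(lengths)
    )
print(status)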
@@ -0,0 +1,372 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
+"""
+
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import (
+    TRITON_PAD_NUM_PAGE_PER_BLOCK,
+    create_flashmla_kv_indices_triton,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Block constraint from flashinfer requirements
+# From flashinfer.decode._check_trtllm_gen_mla_shape:
+# block_num % (128 / block_size) == 0
+# This imposes that the total number of blocks must be divisible by
+# (128 / block_size). We capture the 128 constant here so we can
+# compute the LCM with other padding constraints.
+TRTLLM_BLOCK_CONSTRAINT = 128
+
+
+@dataclass
+class TRTLLMMLADecodeMetadata:
+    """Metadata for TRTLLM MLA decode operations."""
+
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+    """TRTLLM MLA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # Model parameters
+        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+        # MLA-specific dimensions
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+        # Runtime parameters
+        self.scaling = config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        self.workspace_buffer = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+        self.cuda_graph_kv_indices = None
+        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+    def _calc_padded_blocks(self, max_seq_len: int) -> int:
+        """
+        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+        Args:
+            max_seq_len: Maximum sequence length in tokens
+
+        Returns:
+            Number of blocks padded to satisfy all constraints
+        """
+        blocks = triton.cdiv(max_seq_len, self.page_size)
+
+        # Apply dual constraints (take LCM to satisfy both):
+        # 1. TRT-LLM: block_num % (128 / page_size) == 0
+        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+        if blocks % constraint_lcm != 0:
+            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
+        return blocks
+
+    def _create_block_kv_indices(
+        self,
+        batch_size: int,
+        max_blocks: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Create block KV indices tensor using Triton kernel.
+
+        Args:
+            batch_size: Batch size
+            max_blocks: Maximum number of blocks per sequence
+            req_pool_indices: Request pool indices
+            seq_lens: Sequence lengths
+            device: Target device
+
+        Returns:
+            Block KV indices tensor
+        """
+        block_kv_indices = torch.full(
+            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+        )
+
+        create_flashmla_kv_indices_triton[(batch_size,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_blocks,
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+        return block_kv_indices
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MLA."""
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+        self.cuda_graph_kv_indices = torch.full(
+            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
+        )
+        self.cuda_graph_workspace = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+        # Custom fast-path for decode/idle without speculative execution.
+        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # Update block indices for new sequences.
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            metadata.block_kv_indices,
+            self.req_to_token.stride(0),
+            metadata.block_kv_indices.shape[1],
+            TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            self.page_size,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (
+            forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.spec_info is None
+        ):
+            return super().init_forward_metadata(forward_batch)
+
+        bs = forward_batch.batch_size
+
+        # Get maximum sequence length.
+        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+            max_seq = forward_batch.seq_lens_cpu.max().item()
+        else:
+            max_seq = forward_batch.seq_lens.max().item()
+
+        max_seqlen_pad = self._calc_padded_blocks(max_seq)
+        block_kv_indices = self._create_block_kv_indices(
+            bs,
+            max_seqlen_pad,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.seq_lens.device,
+        )
+
+        self.forward_metadata = TRTLLMMLADecodeMetadata(
+            self.workspace_buffer, block_kv_indices
+        )
+        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MLA kernel."""
+        # Save KV cache if requested
+        if k is not None and save_kv_cache:
+            cache_loc = forward_batch.out_cache_loc
+            if k_rope is not None:
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                    layer, cache_loc, k, k_rope
+                )
+            elif v is not None:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        # Prepare query tensor inline
+        if q_rope is not None:
+            # q contains NOPE part (v_head_dim)
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+        else:
+            # q already has both parts
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+        if query.dim() == 3:
+            query = query.unsqueeze(1)
+
+        # Prepare KV cache inline
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+        # TRT-LLM expects single KV data with extra dimension
+        kv_cache = pages.unsqueeze(1)
+
+        # Get metadata
+        metadata = (
+            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+            or self.forward_metadata
+        )
+
+        # Scale computation for TRTLLM MLA kernel:
+        # - BMM1 scale = q_scale * k_scale * softmax_scale
+        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
+        # - k_scale is read from model checkpoint if available
+        # TODO: Change once fp8 path is supported
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+
+        bmm1_scale = q_scale * k_scale * layer.scaling
+
+        # Call TRT-LLM kernel
+        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=metadata.workspace,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            block_tables=metadata.block_kv_indices,
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            bmm1_scale=bmm1_scale,
+        )
+
+        # Extract value projection part and reshape
+        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+        return output
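Reviewer note on _calc_padded_blocks in the new backend above: the block count is rounded up to the least common multiple of the TRT-LLM constraint (128 / page_size) and the Triton page-table burst size (64). A dependency-free sketch of the same arithmetic, with triton.cdiv replaced by plain ceil division:

import math

TRTLLM_BLOCK_CONSTRAINT = 128
TRITON_PAD_NUM_PAGE_PER_BLOCK = 64

def cdiv(a: int, b: int) -> int:
    return -(-a // b)  # ceil division without importing triton

def calc_padded_blocks(max_seq_len: int, page_size: int) -> int:
    blocks = cdiv(max_seq_len, page_size)
    constraint_lcm = math.lcm(
        TRTLLM_BLOCK_CONSTRAINT // page_size, TRITON_PAD_NUM_PAGE_PER_BLOCK
    )
    if blocks % constraint_lcm != 0:
        blocks = cdiv(blocks, constraint_lcm) * constraint_lcm
    return blocks

# page_size=32: 100 tokens -> 4 raw blocks -> padded to lcm(4, 64) = 64
assert calc_padded_blocks(100, 32) == 64
# page_size=64: constraint is lcm(2, 64) = 64, so 130 raw blocks pad to 192
assert calc_padded_blocks(8320, 64) == 192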
@@ -1,6 +1,11 @@
 import triton
 import triton.language as tl
 
+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+# Number of pages that the kernel writes per iteration.
+# Exposed here so other Python modules can import it instead of hard-coding 64.
+TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+
 
 @triton.jit
 def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
     BLOCK_SIZE: tl.constexpr = 4096
-    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)
 
     # find the req pool idx, this is for batch to token
@@ -25,14 +25,22 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     silu_and_mul_triton_kernel,
     tma_align_input_scale,
 )
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
+from sglang.srt.layers.moe.fused_moe_triton.layer import (
+    FlashInferFusedMoE,
+    FusedMoE,
+    should_use_flashinfer_trtllm_moe,
+)
 from sglang.srt.layers.moe.topk import TopKOutput
 from sglang.srt.layers.quantization import deep_gemm_wrapper
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8 import (
+    Fp8Config,
+    Fp8MoEMethod,
+    get_tile_tokens_dim,
+)
 from sglang.srt.layers.quantization.fp8_kernel import (
     is_fp8_fnuz,
     sglang_per_token_group_quant_fp8,
@@ -49,7 +57,6 @@ from sglang.srt.utils import (
     get_bool_env_var,
     is_hip,
     is_npu,
-    next_power_of_2,
 )
 
 if TYPE_CHECKING:
@@ -63,10 +70,7 @@ _is_hip = is_hip()
 _is_npu = is_npu()
 _is_fp8_fnuz = is_fp8_fnuz()
 _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
-use_flashinfer_trtllm_moe = (
-    global_server_args_dict["enable_flashinfer_trtllm_moe"]
-    and global_server_args_dict["enable_ep_moe"]
-)
+
 
 if not (_is_npu or _is_hip):
     from sgl_kernel import silu_and_mul
@@ -76,26 +80,9 @@ if _use_aiter:
     from aiter.fused_moe import fused_moe
     from aiter.ops.shuffle import shuffle_weight
 
-if use_flashinfer_trtllm_moe:
-    try:
-        import flashinfer.fused_moe as fi_fused_moe
-    except ImportError:
-        fi_fused_moe = None
-        use_flashinfer_trtllm_moe = False
-
 logger = logging.getLogger(__name__)
 
 
-def _get_tile_tokens_dim(num_tokens, top_k, num_experts):
-    # Guess tokens per expert assuming perfect expert distribution first.
-    num_tokens_per_expert = (num_tokens * top_k) // num_experts
-    # And pad the number to the next power of 2.
-    tile_tokens_dim = next_power_of_2(num_tokens_per_expert)
-    # Cap to 8-64 tokens per CTA tile as it's the range supported by the kernel.
-    tile_tokens_dim = min(max(tile_tokens_dim, 8), 64)
-    return tile_tokens_dim
-
-
 class EPMoE(FusedMoE):
     """
     MoE Expert Parallel Impl
@@ -731,10 +718,10 @@ class FlashInferEPMoE(EPMoE):
         self.num_expert_group = num_expert_group
         self.topk_group = topk_group
         self.correction_bias = correction_bias
-        self.use_flashinfer_trtllm_moe = use_flashinfer_trtllm_moe
+        self.use_flashinfer_trtllm_moe = should_use_flashinfer_trtllm_moe()
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
-        assert use_flashinfer_trtllm_moe
+        assert self.use_flashinfer_trtllm_moe
         assert (
             self.activation == "silu"
         ), "Only silu is supported for flashinfer blockscale fp8 moe"
@@ -747,8 +734,9 @@
         a_q, a_sf = sglang_per_token_group_quant_fp8(hidden_states, self.block_shape[1])
         # NOTE: scales of hidden states have to be transposed!
         a_sf_t = a_sf.t().contiguous()
-        assert fi_fused_moe is not None
-        return fi_fused_moe.trtllm_fp8_block_scale_moe(
+        from flashinfer.fused_moe import trtllm_fp8_block_scale_moe
+
+        return trtllm_fp8_block_scale_moe(
             routing_logits=router_logits.to(torch.float32),
             routing_bias=self.correction_bias.to(hidden_states.dtype),
             hidden_states=a_q,
@@ -765,7 +753,7 @@
             local_expert_offset=self.start_expert_id,
             local_num_experts=self.num_local_experts,
             routed_scaling_factor=self.routed_scaling_factor,
-            tile_tokens_dim=_get_tile_tokens_dim(
+            tile_tokens_dim=get_tile_tokens_dim(
                 hidden_states.shape[0], self.top_k, self.num_experts
             ),
             routing_method_type=2,  # DeepSeek-styled routing method
@@ -779,9 +767,6 @@ def get_moe_impl_class():
     if global_server_args_dict["enable_flashinfer_cutlass_moe"]:
         # Must come before EPMoE because FusedMoE also supports enable_ep_moe
        return FusedMoE
     if global_server_args_dict["enable_ep_moe"]:
-        return EPMoE
-    return FusedMoE
-    if use_flashinfer_trtllm_moe:
-        # Must come before EPMoE because FusedMoE also supports enable_ep_moe
-        return FlashInferEPMoE
+        return FlashInferEPMoE if should_use_flashinfer_trtllm_moe() else EPMoE
+    return FlashInferFusedMoE if should_use_flashinfer_trtllm_moe() else FusedMoE
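Reviewer note on get_moe_impl_class above: the module-level use_flashinfer_trtllm_moe flag is gone, the decision is now made at call time via should_use_flashinfer_trtllm_moe() (imported from fused_moe_triton.layer), and the TRT-LLM variants are chosen inside both the EP and the non-EP branches. A rough decision table implied by the new code; the class names come from the diff, but the boolean inputs are stand-ins for the real global_server_args_dict lookups:

def pick_moe_impl(
    enable_flashinfer_cutlass_moe: bool,
    enable_ep_moe: bool,
    use_flashinfer_trtllm_moe: bool,
) -> str:
    if enable_flashinfer_cutlass_moe:
        # Must come before the EP branch, as in the original function.
        return "FusedMoE"
    if enable_ep_moe:
        return "FlashInferEPMoE" if use_flashinfer_trtllm_moe else "EPMoE"
    return "FlashInferFusedMoE" if use_flashinfer_trtllm_moe else "FusedMoE"

assert pick_moe_impl(False, True, True) == "FlashInferEPMoE"
assert pick_moe_impl(False, False, True) == "FlashInferFusedMoE"
assert pick_moe_impl(True, True, True) == "FusedMoE"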