sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0

sglang/srt/entrypoints/http_server.py

@@ -45,6 +45,7 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse

 from sglang.srt.disaggregation.utils import (
     FAKE_BOOTSTRAP_HOST,
+    DisaggregationMode,
     register_disaggregation_server,
 )
 from sglang.srt.entrypoints.engine import _launch_subprocesses
@@ -88,7 +89,7 @@ from sglang.srt.managers.io_struct import (
     VertexGenerateReqInput,
 )
 from sglang.srt.managers.template_manager import TemplateManager
-from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.managers.tokenizer_manager import ServerStatus, TokenizerManager
 from sglang.srt.metrics.func_timer import enable_func_timer
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import ServerArgs
@@ -230,23 +231,28 @@ async def validate_json_request(raw_request: Request):


 @app.get("/health")
-async def health() -> Response:
-    """Check the health of the http server."""
-    return Response(status_code=200)
-
-
 @app.get("/health_generate")
 async def health_generate(request: Request) -> Response:
-    """Check the health of the inference server by generating one token."""
+    """
+    Check the health of the inference server by sending a special request to generate one token.
+
+    If the server is running something, this request will be ignored, so it creates zero overhead.
+    If the server is not running anything, this request will be run, so we know whether the server is healthy.
+    """
+
     if _global_state.tokenizer_manager.gracefully_exit:
         logger.info("Health check request received during shutdown. Returning 503.")
         return Response(status_code=503)

+    if not _global_state.tokenizer_manager.server_status.is_healthy():
+        return Response(status_code=503)
+
     sampling_params = {"max_new_tokens": 1, "temperature": 0.0}
     rid = f"HEALTH_CHECK_{time.time()}"

     if _global_state.tokenizer_manager.is_image_gen:
-        raise NotImplementedError()
+        # Keep this branch for some internal use cases.
+        raise NotImplementedError("Image generation is not supported yet.")
     elif _global_state.tokenizer_manager.is_generation:
         gri = GenerateReqInput(
             rid=rid,
@@ -254,6 +260,12 @@ async def health_generate(request: Request) -> Response:
             sampling_params=sampling_params,
             log_metrics=False,
         )
+        if (
+            _global_state.tokenizer_manager.server_args.disaggregation_mode
+            != DisaggregationMode.NULL
+        ):
+            gri.bootstrap_host = FAKE_BOOTSTRAP_HOST
+            gri.bootstrap_room = 0
     else:
         gri = EmbeddingReqInput(
             rid=rid, input_ids=[0], sampling_params=sampling_params, log_metrics=False
@@ -263,9 +275,6 @@ async def health_generate(request: Request) -> Response:
         async for _ in _global_state.tokenizer_manager.generate_request(gri, request):
            break

-    # This request is a special request.
-    # If the server already has something running, this request will be ignored, so it creates zero overhead.
-    # If the server is not running, this request will be run, so we know whether the server is healthy.
    task = asyncio.create_task(gen())

    # As long as we receive any response from the detokenizer/scheduler, we consider the server is healthy.
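
The rewritten health handler above boils down to one pattern: fire a one-token generation as a background task and treat any output that arrives before a deadline as proof of liveness (a busy server ignores the probe, an idle server runs it). A minimal standalone sketch of that pattern, assuming a hypothetical generate_one_token() async generator rather than sglang's tokenizer manager:

import asyncio

async def probe_health(generate_one_token, timeout_s: float = 20.0) -> bool:
    """Return True if the backend yields any output before the deadline."""
    got_output = asyncio.Event()

    async def gen():
        # Any yielded chunk is enough to call the server healthy.
        async for _ in generate_one_token():
            got_output.set()
            break

    task = asyncio.create_task(gen())
    try:
        await asyncio.wait_for(got_output.wait(), timeout=timeout_s)
        return True
    except asyncio.TimeoutError:
        return False
    finally:
        task.cancel()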

@@ -1032,8 +1041,10 @@ def _execute_server_warmup(
                 timeout=600,
             )
             assert res.status_code == 200, f"{res}"
+            _global_state.tokenizer_manager.server_status = ServerStatus.Up
+
         else:
-            logger.info(f"Start of prefill warmup ...")
+            logger.info(f"Start of pd disaggregation warmup ...")
             json_data = {
                 "sampling_params": {
                     "temperature": 0.0,
@@ -1055,9 +1066,18 @@ def _execute_server_warmup(
                 headers=headers,
                 timeout=1800, # because of deep gemm precache is very long if not precache.
             )
-            logger.info(
-                f"End of prefill warmup with status {res.status_code}, resp: {res.json()}"
-            )
+            if res.status_code == 200:
+                logger.info(
+                    f"End of prefill disaggregation mode warmup with status {res.status_code}, resp: {res.json()}"
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.Up
+            else:
+                logger.info(
+                    "Prefill disaggregation mode warm Up Failed, status code: {}".format(
+                        res.status_code
+                    )
+                )
+                _global_state.tokenizer_manager.server_status = ServerStatus.UnHealthy

    except Exception:
        last_traceback = get_exception_traceback()
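
The warmup paths above set _global_state.tokenizer_manager.server_status to ServerStatus.Up or ServerStatus.UnHealthy, and /health now returns 503 whenever server_status.is_healthy() is false. The enum itself is defined in sglang/srt/managers/tokenizer_manager.py (that hunk is not shown here); a minimal sketch of the shape these call sites assume, with any member beyond Up and UnHealthy being a guess:

from enum import Enum

class ServerStatus(Enum):
    Starting = "Starting"    # assumed pre-warmup state; actual members may differ
    Up = "Up"
    UnHealthy = "UnHealthy"

    def is_healthy(self) -> bool:
        # Only a fully warmed-up server should pass the /health check.
        return self == ServerStatus.Up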

sglang/srt/eplb/expert_distribution.py

@@ -288,12 +288,14 @@ class _SinglePassGatherer(ABC):
         )

         if server_args.expert_distribution_recorder_mode == "stat_approx":
-            if server_args.enable_deepep_moe and (server_args.deepep_mode == "normal"):
+            if server_args.moe_a2a_backend is not None and (
+                server_args.deepep_mode == "normal"
+            ):
                 return _DeepepNormalSinglePassGatherer(expert_location_metadata, rank)
             else:
                 raise NotImplementedError

-        if server_args.enable_deepep_moe:
+        if server_args.moe_a2a_backend is not None:
             if server_args.deepep_mode == "normal":
                 return _SelectExpertsSinglePassGatherer(expert_location_metadata, rank)
             elif server_args.deepep_mode == "low_latency":

sglang/srt/hf_transformers_utils.py

@@ -14,7 +14,6 @@
 """Utilities for Huggingface Transformers."""

 import contextlib
-import logging
 import os
 import warnings
 from pathlib import Path
@@ -45,7 +44,7 @@ from sglang.srt.configs import (
 )
 from sglang.srt.configs.internvl import InternVLChatConfig
 from sglang.srt.connector import create_remote_connector
-from sglang.srt.utils import is_remote_url, lru_cache_frozenset
+from sglang.srt.utils import is_remote_url, logger, lru_cache_frozenset

 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
@@ -317,15 +316,31 @@ def get_processor(

     if config.model_type not in {"llava", "clip"}:
         kwargs["use_fast"] = use_fast
+    try:
+        processor = AutoProcessor.from_pretrained(
+            tokenizer_name,
+            *args,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            **kwargs,
+        )

-    processor = AutoProcessor.from_pretrained(
-        tokenizer_name,
-        *args,
-        trust_remote_code=trust_remote_code,
-        revision=revision,
-        **kwargs,
-    )
-
+    except ValueError as e:
+        error_message = str(e)
+        if "does not have a slow version" in error_message:
+            logger.info(
+                f"Processor {tokenizer_name} does not have a slow version. Automatically use fast version"
+            )
+            kwargs["use_fast"] = True
+            processor = AutoProcessor.from_pretrained(
+                tokenizer_name,
+                *args,
+                trust_remote_code=trust_remote_code,
+                revision=revision,
+                **kwargs,
+            )
+        else:
+            raise e
     tokenizer = get_tokenizer_from_processor(processor)

     attach_additional_stop_token_ids(tokenizer)
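
The get_processor() change above retries with a fast processor when transformers raises a ValueError whose message contains "does not have a slow version", i.e. when a checkpoint only ships a fast tokenizer/processor implementation. The same fallback in isolation, as a hypothetical helper (not sglang's API):

from transformers import AutoProcessor

def load_processor_with_fallback(name: str, use_fast: bool = False, **kwargs):
    """Request the slow processor first; fall back to the fast one if that is all the repo provides."""
    try:
        return AutoProcessor.from_pretrained(name, use_fast=use_fast, **kwargs)
    except ValueError as e:
        if "does not have a slow version" in str(e):
            # Only a fast implementation exists for this checkpoint; retry with it.
            return AutoProcessor.from_pretrained(name, use_fast=True, **kwargs)
        raise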

sglang/srt/layers/attention/cutlass_mla_backend.py

@@ -102,7 +102,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             block_kv_indices,
             self.req_to_token.stride(0),
             max_seqlen_pad,
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         workspace_size = cutlass_mla_get_workspace_size(
             max_seqlen_pad * PAGE_SIZE, bs, num_kv_splits=1
@@ -165,7 +165,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
             self.cuda_graph_kv_indices,
             self.req_to_token.stride(0),
             self.cuda_graph_kv_indices.stride(0),
-            PAGE_SIZE,
+            PAGED_SIZE=PAGE_SIZE,
         )
         self.forward_metadata = CutlassMLADecodeMetadata(
             self.cuda_graph_mla_workspace,
@@ -206,7 +206,7 @@ class CutlassMLABackend(FlashInferMLAAttnBackend):
                 self.cuda_graph_kv_indices,
                 self.req_to_token.stride(0),
                 self.cuda_graph_kv_indices.stride(0),
-                PAGE_SIZE,
+                PAGED_SIZE=PAGE_SIZE,
             )
         else:
             super().init_forward_metadata_replay_cuda_graph(

sglang/srt/layers/attention/flashattention_backend.py

@@ -1406,7 +1406,7 @@ class FlashAttentionBackend(AttentionBackend):
                 )
                 metadata.page_table = self.decode_cuda_graph_metadata[
                     "page_table_draft_decode"
-                ][req_pool_indices, :]
+                ][:bs, :]
                 self.decode_cuda_graph_metadata[bs] = metadata
             else:
                 # When top k > 1, we need two specific draft decode metadata, and then merge states
@@ -1424,7 +1424,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.draft_decode_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1461,7 +1461,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.max_seq_len_k = seq_lens.max().item()
             # Precompute page table
             metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
-                req_pool_indices, :
+                :bs, :
             ]
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
@@ -1498,9 +1498,7 @@ class FlashAttentionBackend(AttentionBackend):
                     : (bs + 1)
                 ]

-                metadata.page_table = self.target_verify_metadata["page_table"][
-                    req_pool_indices, :
-                ]
+                metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]

                 self.target_verify_metadata[bs] = metadata
             else:
@@ -1519,7 +1517,7 @@ class FlashAttentionBackend(AttentionBackend):
                 ][: bs + 1]
                 metadata.page_table = self.target_verify_metadata_topk_normal[
                     "page_table"
-                ][req_pool_indices, :]
+                ][:bs, :]

                 # 2. The second half of metadata for draft tokens (per_batch_num_tokens = topk)
                 metadata_expand.cache_seqlens_int32 = (
@@ -1562,9 +1560,7 @@ class FlashAttentionBackend(AttentionBackend):
             metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
                 : (bs + 1)
             ]
-            metadata.page_table = self.draft_extend_metadata["page_table"][
-                req_pool_indices, :
-            ]
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]

             self.draft_extend_metadata[bs] = metadata

@@ -1578,7 +1574,7 @@ class FlashAttentionBackend(AttentionBackend):
             ][: (encoder_bs + 1)]

             metadata.encoder_page_table = self.encoder_metadata["encoder_page_table"][
-                req_pool_indices, :
+                :bs, :
             ]

         self.forward_metadata = metadata
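
Every flashattention_backend.py hunk above makes the same substitution: the preallocated CUDA-graph metadata buffers are now sliced by batch position ([:bs, :]) rather than gathered by request pool index ([req_pool_indices, :]). The two forms pick different rows whenever the pool indices are not exactly 0..bs-1, which a tiny illustrative tensor example (not sglang code) makes concrete:

import torch

page_table = torch.arange(20).reshape(5, 4)  # 5 preallocated rows of metadata
req_pool_indices = torch.tensor([3, 1])      # pool slots of the 2 active requests
bs = req_pool_indices.numel()

gathered = page_table[req_pool_indices, :]   # rows 3 and 1 (advanced indexing, copies)
sliced = page_table[:bs, :]                  # rows 0 and 1 (a contiguous view)

print(gathered.tolist())  # [[12, 13, 14, 15], [4, 5, 6, 7]]
print(sliced.tolist())    # [[0, 1, 2, 3], [4, 5, 6, 7]]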

sglang/srt/layers/attention/trtllm_mla_backend.py (new file)

@@ -0,0 +1,372 @@
+from __future__ import annotations
+
+"""
+Support attention backend for TRTLLM MLA kernels from flashinfer.
+"""
+
+import math
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import (
+    TRITON_PAD_NUM_PAGE_PER_BLOCK,
+    create_flashmla_kv_indices_triton,
+)
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+from sglang.srt.utils import is_flashinfer_available
+
+if is_flashinfer_available():
+    import flashinfer
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+# Constants
+DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+
+# Block constraint from flashinfer requirements
+# From flashinfer.decode._check_trtllm_gen_mla_shape:
+#   block_num % (128 / block_size) == 0
+# This imposes that the total number of blocks must be divisible by
+# (128 / block_size). We capture the 128 constant here so we can
+# compute the LCM with other padding constraints.
+TRTLLM_BLOCK_CONSTRAINT = 128
+
+
+@dataclass
+class TRTLLMMLADecodeMetadata:
+    """Metadata for TRTLLM MLA decode operations."""
+
+    workspace: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+
+class TRTLLMMLABackend(FlashInferMLAAttnBackend):
+    """TRTLLM MLA attention kernel from flashinfer."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+
+        config = model_runner.model_config
+
+        # Model parameters
+        self.num_q_heads = config.num_attention_heads // get_attention_tp_size()
+        self.num_kv_heads = config.get_num_kv_heads(get_attention_tp_size())
+        self.num_local_heads = config.num_attention_heads // get_attention_tp_size()
+
+        # MLA-specific dimensions
+        self.kv_lora_rank = config.kv_lora_rank
+        self.qk_nope_head_dim = config.qk_nope_head_dim
+        self.qk_rope_head_dim = config.qk_rope_head_dim
+        self.v_head_dim = config.v_head_dim
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+        # Runtime parameters
+        self.scaling = config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.page_size = model_runner.page_size
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+        # Workspace allocation
+        self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
+        self.workspace_buffer = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+        # CUDA graph state
+        self.decode_cuda_graph_metadata = {}
+        self.cuda_graph_kv_indices = None
+        self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
+
+    def _calc_padded_blocks(self, max_seq_len: int) -> int:
+        """
+        Calculate padded block count that satisfies both TRT-LLM and Triton constraints.
+
+        Args:
+            max_seq_len: Maximum sequence length in tokens
+
+        Returns:
+            Number of blocks padded to satisfy all constraints
+        """
+        blocks = triton.cdiv(max_seq_len, self.page_size)
+
+        # Apply dual constraints (take LCM to satisfy both):
+        # 1. TRT-LLM: block_num % (128 / page_size) == 0
+        # 2. Triton: page table builder uses 64-index bursts, needs multiple of 64
+        trtllm_constraint = TRTLLM_BLOCK_CONSTRAINT // self.page_size
+        constraint_lcm = math.lcm(trtllm_constraint, TRITON_PAD_NUM_PAGE_PER_BLOCK)
+
+        if blocks % constraint_lcm != 0:
+            blocks = triton.cdiv(blocks, constraint_lcm) * constraint_lcm
+        return blocks
+
+    def _create_block_kv_indices(
+        self,
+        batch_size: int,
+        max_blocks: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        device: torch.device,
+    ) -> torch.Tensor:
+        """
+        Create block KV indices tensor using Triton kernel.
+
+        Args:
+            batch_size: Batch size
+            max_blocks: Maximum number of blocks per sequence
+            req_pool_indices: Request pool indices
+            seq_lens: Sequence lengths
+            device: Target device
+
+        Returns:
+            Block KV indices tensor
+        """
+        block_kv_indices = torch.full(
+            (batch_size, max_blocks), -1, dtype=torch.int32, device=device
+        )
+
+        create_flashmla_kv_indices_triton[(batch_size,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_blocks,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+        return block_kv_indices
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        max_num_tokens: int,
+        kv_indices_buf: Optional[torch.Tensor] = None,
+    ):
+        """Initialize CUDA graph state for TRTLLM MLA."""
+        max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
+
+        self.cuda_graph_kv_indices = torch.full(
+            (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
+        )
+        self.cuda_graph_workspace = torch.empty(
+            self.workspace_size, dtype=torch.int8, device=self.device
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        """Initialize metadata for CUDA graph capture."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+        # Custom fast-path for decode/idle without speculative execution.
+        max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
+        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices,
+            seq_lens,
+            None,
+            block_kv_indices,
+            self.req_to_token.stride(0),
+            max_seqlen_pad,
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        """Replay CUDA graph with new inputs."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+            return super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # Update block indices for new sequences.
+        create_flashmla_kv_indices_triton[(bs,)](
+            self.req_to_token,
+            req_pool_indices[:bs],
+            seq_lens[:bs],
+            None,
+            metadata.block_kv_indices,
+            self.req_to_token.stride(0),
+            metadata.block_kv_indices.shape[1],
+            NUM_PAGE_PER_BLOCK=TRITON_PAD_NUM_PAGE_PER_BLOCK,
+            PAGED_SIZE=self.page_size,
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self) -> int:
+        """Get the fill value for sequence lengths in CUDA graph."""
+        return 1
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize the metadata for a forward pass."""
+        # Delegate to parent for non-decode modes or when speculative execution is used.
+        if not (
+            forward_batch.forward_mode.is_decode_or_idle()
+            and forward_batch.spec_info is None
+        ):
+            return super().init_forward_metadata(forward_batch)
+
+        bs = forward_batch.batch_size
+
+        # Get maximum sequence length.
+        if getattr(forward_batch, "seq_lens_cpu", None) is not None:
+            max_seq = forward_batch.seq_lens_cpu.max().item()
+        else:
+            max_seq = forward_batch.seq_lens.max().item()
+
+        max_seqlen_pad = self._calc_padded_blocks(max_seq)
+        block_kv_indices = self._create_block_kv_indices(
+            bs,
+            max_seqlen_pad,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.seq_lens.device,
+        )
+
+        self.forward_metadata = TRTLLMMLADecodeMetadata(
+            self.workspace_buffer, block_kv_indices
+        )
+        forward_batch.decode_trtllm_mla_metadata = self.forward_metadata
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+        q_rope: Optional[torch.Tensor] = None,
+        k_rope: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        """Run forward for decode using TRTLLM MLA kernel."""
+        # Save KV cache if requested
+        if k is not None and save_kv_cache:
+            cache_loc = forward_batch.out_cache_loc
+            if k_rope is not None:
+                forward_batch.token_to_kv_pool.set_mla_kv_buffer(
+                    layer, cache_loc, k, k_rope
+                )
+            elif v is not None:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+
+        # Prepare query tensor inline
+        if q_rope is not None:
+            # q contains NOPE part (v_head_dim)
+            q_nope = q.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+            q_rope_reshaped = q_rope.view(
+                -1, layer.tp_q_head_num, layer.head_dim - layer.v_head_dim
+            )
+            query = torch.cat([q_nope, q_rope_reshaped], dim=-1)
+        else:
+            # q already has both parts
+            query = q.view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Ensure query has shape [bs, acc_q_len, num_q_heads, head_dim] when seq_len 1
+        if query.dim() == 3:
+            query = query.unsqueeze(1)
+
+        # Prepare KV cache inline
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+        pages = k_cache.view(-1, self.page_size, self.kv_cache_dim)
+        # TRT-LLM expects single KV data with extra dimension
+        kv_cache = pages.unsqueeze(1)
+
+        # Get metadata
+        metadata = (
+            getattr(forward_batch, "decode_trtllm_mla_metadata", None)
+            or self.forward_metadata
+        )
+
+        # Scale computation for TRTLLM MLA kernel:
+        # - BMM1 scale = q_scale * k_scale * softmax_scale
+        # - For FP16 path we keep q_scale = 1.0, softmax_scale = 1/sqrt(head_dim) which is pre-computed as layer.scaling
+        # - k_scale is read from model checkpoint if available
+        # TODO: Change once fp8 path is supported
+        q_scale = 1.0
+        k_scale = (
+            layer.k_scale_float
+            if getattr(layer, "k_scale_float", None) is not None
+            else 1.0
+        )
+
+        bmm1_scale = q_scale * k_scale * layer.scaling
+
+        # Call TRT-LLM kernel
+        raw_out = flashinfer.decode.trtllm_batch_decode_with_kv_cache_mla(
+            query=query,
+            kv_cache=kv_cache,
+            workspace_buffer=metadata.workspace,
+            qk_nope_head_dim=self.qk_nope_head_dim,
+            kv_lora_rank=self.kv_lora_rank,
+            qk_rope_head_dim=self.qk_rope_head_dim,
+            block_tables=metadata.block_kv_indices,
+            seq_lens=forward_batch.seq_lens.to(torch.int32),
+            max_seq_len=int(metadata.block_kv_indices.shape[1] * self.page_size),
+            bmm1_scale=bmm1_scale,
+        )
+
+        # Extract value projection part and reshape
+        raw_out_v = raw_out[..., : layer.v_head_dim].contiguous()
+        output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+        return output
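
_calc_padded_blocks above pads the per-sequence block count so it is a multiple of both the TRT-LLM divisibility requirement (128 / page_size) and the Triton page-table builder's 64-page burst (TRITON_PAD_NUM_PAGE_PER_BLOCK). A quick standalone check of that arithmetic in plain Python (no Triton needed); the page_size of 32 and the 1000-token sequence are made-up numbers for illustration:

import math

def calc_padded_blocks(max_seq_len: int, page_size: int) -> int:
    cdiv = lambda a, b: -(-a // b)  # ceiling division, same as triton.cdiv
    blocks = cdiv(max_seq_len, page_size)            # 1000 tokens / 32 -> 32 blocks
    constraint_lcm = math.lcm(128 // page_size, 64)  # lcm(4, 64) = 64
    if blocks % constraint_lcm != 0:
        blocks = cdiv(blocks, constraint_lcm) * constraint_lcm  # pad 32 -> 64
    return blocks

assert calc_padded_blocks(1000, 32) == 64
assert calc_padded_blocks(64 * 32, 32) == 64  # already aligned, left unchanged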

sglang/srt/layers/attention/utils.py

@@ -1,6 +1,11 @@
 import triton
 import triton.language as tl

+# Keep this in sync with the Triton kernel inside `create_flashmla_kv_indices_triton`.
+# Number of pages that the kernel writes per iteration.
+# Exposed here so other Python modules can import it instead of hard-coding 64.
+TRITON_PAD_NUM_PAGE_PER_BLOCK = 64
+

 @triton.jit
 def create_flashinfer_kv_indices_triton(
@@ -50,10 +55,10 @@ def create_flashmla_kv_indices_triton(
     kv_indices_ptr,
     req_to_token_ptr_stride: tl.constexpr,
     kv_indices_ptr_stride: tl.constexpr,
+    NUM_PAGE_PER_BLOCK: tl.constexpr = TRITON_PAD_NUM_PAGE_PER_BLOCK,
     PAGED_SIZE: tl.constexpr = 64,
 ):
     BLOCK_SIZE: tl.constexpr = 4096
-    NUM_PAGE_PER_BLOCK: tl.constexpr = 64
     pid = tl.program_id(axis=0)

     # find the req pool idx, this is for batch to token