sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (203)
  1. sglang/bench_one_batch.py +0 -7
  2. sglang/bench_one_batch_server.py +7 -2
  3. sglang/bench_serving.py +3 -3
  4. sglang/eval/llama3_eval.py +0 -1
  5. sglang/srt/configs/model_config.py +25 -9
  6. sglang/srt/configs/update_config.py +40 -5
  7. sglang/srt/constrained/xgrammar_backend.py +23 -11
  8. sglang/srt/conversation.py +2 -15
  9. sglang/srt/disaggregation/ascend/conn.py +1 -3
  10. sglang/srt/disaggregation/base/conn.py +1 -0
  11. sglang/srt/disaggregation/decode.py +1 -2
  12. sglang/srt/disaggregation/launch_lb.py +7 -1
  13. sglang/srt/disaggregation/mini_lb.py +11 -5
  14. sglang/srt/disaggregation/mooncake/conn.py +141 -47
  15. sglang/srt/disaggregation/prefill.py +261 -5
  16. sglang/srt/disaggregation/utils.py +2 -1
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  18. sglang/srt/distributed/device_communicators/pynccl.py +68 -18
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
  20. sglang/srt/distributed/naive_distributed.py +112 -0
  21. sglang/srt/distributed/parallel_state.py +90 -4
  22. sglang/srt/entrypoints/context.py +20 -1
  23. sglang/srt/entrypoints/engine.py +29 -4
  24. sglang/srt/entrypoints/http_server.py +76 -0
  25. sglang/srt/entrypoints/openai/protocol.py +4 -2
  26. sglang/srt/entrypoints/openai/serving_chat.py +23 -6
  27. sglang/srt/entrypoints/openai/serving_completions.py +10 -1
  28. sglang/srt/entrypoints/openai/serving_responses.py +2 -2
  29. sglang/srt/eplb/expert_distribution.py +2 -3
  30. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  31. sglang/srt/hf_transformers_utils.py +24 -0
  32. sglang/srt/host_shared_memory.py +83 -0
  33. sglang/srt/layers/attention/ascend_backend.py +132 -22
  34. sglang/srt/layers/attention/flashattention_backend.py +24 -17
  35. sglang/srt/layers/attention/flashinfer_backend.py +14 -3
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
  37. sglang/srt/layers/attention/triton_backend.py +109 -73
  38. sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
  39. sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
  40. sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
  41. sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
  42. sglang/srt/layers/attention/utils.py +94 -15
  43. sglang/srt/layers/attention/vision.py +40 -13
  44. sglang/srt/layers/attention/vision_utils.py +65 -0
  45. sglang/srt/layers/communicator.py +58 -10
  46. sglang/srt/layers/dp_attention.py +137 -27
  47. sglang/srt/layers/elementwise.py +94 -0
  48. sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
  49. sglang/srt/layers/layernorm.py +8 -1
  50. sglang/srt/layers/linear.py +24 -0
  51. sglang/srt/layers/logits_processor.py +16 -18
  52. sglang/srt/layers/moe/__init__.py +31 -0
  53. sglang/srt/layers/moe/ep_moe/layer.py +37 -33
  54. sglang/srt/layers/moe/fused_moe_native.py +14 -25
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  57. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  63. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  64. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  65. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  66. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  68. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  69. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
  70. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
  71. sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
  72. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
  73. sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
  74. sglang/srt/layers/moe/moe_runner/base.py +13 -0
  75. sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
  76. sglang/srt/layers/moe/router.py +15 -9
  77. sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
  78. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
  79. sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
  80. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  81. sglang/srt/layers/moe/topk.py +167 -83
  82. sglang/srt/layers/moe/utils.py +159 -18
  83. sglang/srt/layers/multimodal.py +156 -40
  84. sglang/srt/layers/quantization/__init__.py +18 -46
  85. sglang/srt/layers/quantization/awq.py +22 -23
  86. sglang/srt/layers/quantization/base_config.py +2 -6
  87. sglang/srt/layers/quantization/blockwise_int8.py +4 -12
  88. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
  89. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
  90. sglang/srt/layers/quantization/fp8.py +127 -119
  91. sglang/srt/layers/quantization/fp8_kernel.py +195 -24
  92. sglang/srt/layers/quantization/fp8_utils.py +34 -9
  93. sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
  94. sglang/srt/layers/quantization/gptq.py +17 -21
  95. sglang/srt/layers/quantization/marlin_utils.py +26 -8
  96. sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
  97. sglang/srt/layers/quantization/modelopt_quant.py +217 -98
  98. sglang/srt/layers/quantization/moe_wna16.py +10 -15
  99. sglang/srt/layers/quantization/mxfp4.py +222 -39
  100. sglang/srt/layers/quantization/quark/quark.py +390 -0
  101. sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
  102. sglang/srt/layers/quantization/unquant.py +34 -70
  103. sglang/srt/layers/quantization/utils.py +77 -2
  104. sglang/srt/layers/quantization/w4afp8.py +7 -8
  105. sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
  106. sglang/srt/layers/quantization/w8a8_int8.py +5 -13
  107. sglang/srt/layers/radix_attention.py +6 -0
  108. sglang/srt/layers/rotary_embedding.py +1 -0
  109. sglang/srt/layers/sampler.py +5 -2
  110. sglang/srt/lora/layers.py +6 -2
  111. sglang/srt/lora/lora_manager.py +21 -22
  112. sglang/srt/lora/lora_registry.py +3 -3
  113. sglang/srt/lora/mem_pool.py +26 -24
  114. sglang/srt/lora/utils.py +10 -12
  115. sglang/srt/managers/cache_controller.py +80 -19
  116. sglang/srt/managers/detokenizer_manager.py +10 -2
  117. sglang/srt/managers/io_struct.py +23 -0
  118. sglang/srt/managers/mm_utils.py +1 -1
  119. sglang/srt/managers/schedule_batch.py +22 -48
  120. sglang/srt/managers/scheduler.py +28 -20
  121. sglang/srt/managers/session_controller.py +1 -1
  122. sglang/srt/managers/template_manager.py +7 -5
  123. sglang/srt/managers/tokenizer_manager.py +88 -39
  124. sglang/srt/managers/tp_worker.py +1 -0
  125. sglang/srt/managers/utils.py +59 -1
  126. sglang/srt/mem_cache/allocator.py +10 -157
  127. sglang/srt/mem_cache/allocator_ascend.py +147 -0
  128. sglang/srt/mem_cache/chunk_cache.py +1 -1
  129. sglang/srt/mem_cache/hicache_storage.py +14 -4
  130. sglang/srt/mem_cache/memory_pool.py +3 -3
  131. sglang/srt/mem_cache/memory_pool_host.py +35 -2
  132. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
  133. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
  134. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
  135. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
  136. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
  137. sglang/srt/model_executor/cuda_graph_runner.py +33 -33
  138. sglang/srt/model_executor/forward_batch_info.py +11 -10
  139. sglang/srt/model_executor/model_runner.py +93 -78
  140. sglang/srt/model_executor/npu_graph_runner.py +94 -0
  141. sglang/srt/model_loader/loader.py +24 -6
  142. sglang/srt/models/dbrx.py +12 -6
  143. sglang/srt/models/deepseek.py +2 -1
  144. sglang/srt/models/deepseek_nextn.py +5 -2
  145. sglang/srt/models/deepseek_v2.py +226 -223
  146. sglang/srt/models/ernie4.py +2 -2
  147. sglang/srt/models/glm4_moe.py +27 -65
  148. sglang/srt/models/glm4_moe_nextn.py +2 -1
  149. sglang/srt/models/glm4v.py +52 -1
  150. sglang/srt/models/glm4v_moe.py +8 -11
  151. sglang/srt/models/gpt_oss.py +41 -76
  152. sglang/srt/models/granitemoe.py +0 -1
  153. sglang/srt/models/grok.py +376 -48
  154. sglang/srt/models/interns1.py +12 -47
  155. sglang/srt/models/internvl.py +6 -51
  156. sglang/srt/models/llama.py +10 -2
  157. sglang/srt/models/llama4.py +18 -7
  158. sglang/srt/models/minicpm3.py +0 -1
  159. sglang/srt/models/mixtral.py +0 -2
  160. sglang/srt/models/nemotron_nas.py +435 -0
  161. sglang/srt/models/olmoe.py +0 -1
  162. sglang/srt/models/phi4mm.py +3 -21
  163. sglang/srt/models/qwen2.py +2 -2
  164. sglang/srt/models/qwen2_5_vl.py +2 -0
  165. sglang/srt/models/qwen2_moe.py +23 -23
  166. sglang/srt/models/qwen3.py +2 -2
  167. sglang/srt/models/qwen3_classification.py +84 -0
  168. sglang/srt/models/qwen3_moe.py +27 -43
  169. sglang/srt/models/step3_vl.py +8 -3
  170. sglang/srt/models/xverse_moe.py +11 -5
  171. sglang/srt/multimodal/processors/base_processor.py +3 -3
  172. sglang/srt/multimodal/processors/internvl.py +7 -2
  173. sglang/srt/multimodal/processors/llava.py +11 -7
  174. sglang/srt/offloader.py +433 -0
  175. sglang/srt/operations.py +22 -2
  176. sglang/srt/reasoning_parser.py +4 -3
  177. sglang/srt/sampling/sampling_batch_info.py +7 -4
  178. sglang/srt/server_args.py +264 -105
  179. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
  180. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
  181. sglang/srt/speculative/eagle_utils.py +36 -13
  182. sglang/srt/speculative/eagle_worker.py +56 -3
  183. sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
  184. sglang/srt/two_batch_overlap.py +20 -19
  185. sglang/srt/utils.py +68 -70
  186. sglang/test/runners.py +8 -5
  187. sglang/test/test_block_fp8.py +5 -6
  188. sglang/test/test_block_fp8_ep.py +13 -19
  189. sglang/test/test_cutlass_moe.py +4 -6
  190. sglang/test/test_cutlass_w4a8_moe.py +4 -3
  191. sglang/test/test_fp4_moe.py +4 -3
  192. sglang/test/test_marlin_moe.py +1 -1
  193. sglang/test/test_marlin_utils.py +1 -1
  194. sglang/test/test_utils.py +7 -0
  195. sglang/utils.py +0 -1
  196. sglang/version.py +1 -1
  197. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
  198. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
  199. sglang/srt/layers/quantization/fp4.py +0 -557
  200. sglang/srt/layers/quantization/scalar_type.py +0 -352
  201. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
  202. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
  203. {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/attention/trtllm_mha_backend.py
@@ -10,23 +10,30 @@ from typing import TYPE_CHECKING, Optional
 
 import torch
 
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_backend import (
+    FlashInferAttnBackend,
+    FlashInferMultiStepDraftBackend,
+)
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import is_flashinfer_available
 
 if is_flashinfer_available():
     import flashinfer
 
+from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.speculative.spec_info import SpecInfo
 
 # Constants
-DEFAULT_WORKSPACE_SIZE_MB = 128  # Memory workspace size in MB
+DEFAULT_WORKSPACE_SIZE_MB = (
+    512  # Memory workspace size in MB, todo(Yingyi): read from config
+)
 
 # Reuse this workspace buffer across all TRTLLM MHA wrappers
-global_workspace_buffer = None
+global_zero_init_workspace_buffer = None
 
 
 @dataclass
@@ -53,9 +60,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         model_runner: ModelRunner,
         skip_prefill: bool = False,
         kv_indptr_buf: Optional[torch.Tensor] = None,
-        q_indptr_decode_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+        speculative_step_id: int = 0,
     ):
-        super().__init__(model_runner, skip_prefill, kv_indptr_buf, q_indptr_decode_buf)
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
 
         config = model_runner.model_config
 
@@ -73,18 +83,28 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
         # Allocate buffers
-        global global_workspace_buffer
-        if global_workspace_buffer is None:
-            global_workspace_buffer = torch.empty(
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
                 self.workspace_size,
                 dtype=torch.uint8,
                 device=model_runner.device,
             )
-        self.workspace_buffer = global_workspace_buffer
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
 
+        # Speculative decoding
+        # Only support topk <= 1 for now.
+        self.topk = model_runner.server_args.speculative_eagle_topk or 0
+        self.speculative_step_id = speculative_step_id
+        self.target_verify_metadata = {}
+
+        self.speculative_num_draft_tokens = (
+            model_runner.server_args.speculative_num_draft_tokens
+        )
+
         # Forward metadata
         self.forward_metadata: Optional[TRTLLMMHAMetadata] = None
 
@@ -95,11 +115,12 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MHA."""
+        max_num_pages = (self.max_context_len + self.page_size - 1) // self.page_size
         self.decode_cuda_graph_metadata = {
             "cache_seqlens": torch.zeros(max_bs, dtype=torch.int32, device=self.device),
             "page_table": torch.zeros(
                 max_bs,
-                (self.max_context_len + self.page_size - 1) // self.page_size,
+                max_num_pages,
                 dtype=torch.int32,
                 device=self.device,
             ),
@@ -108,6 +129,70 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             ),
         }
 
+        if (
+            self.speculative_num_draft_tokens is not None
+            and self.speculative_num_draft_tokens > 0
+        ):
+            self.decode_cuda_graph_metadata["cu_seqlens_q"] = torch.arange(
+                0, max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["cu_seqlens_k"] = torch.zeros(
+                max_bs + 1, dtype=torch.int32, device=self.device
+            )
+            self.decode_cuda_graph_metadata["page_table_draft_decode"] = torch.zeros(
+                max_bs,
+                max_num_pages,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.target_verify_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.arange(
+                    0,
+                    max_bs * self.speculative_num_draft_tokens + 1,
+                    step=self.speculative_num_draft_tokens,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
+            self.draft_extend_metadata = {
+                "cache_seqlens": torch.zeros(
+                    max_bs, dtype=torch.int32, device=self.device
+                ),
+                "cu_seqlens_q": torch.zeros(
+                    max_bs + 1,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "cu_seqlens_k": torch.zeros(
+                    max_bs + 1, dtype=torch.int32, device=self.device
+                ),
+                "page_table": torch.zeros(
+                    max_bs,
+                    max_num_pages,
+                    dtype=torch.int32,
+                    device=self.device,
+                ),
+                "strided_indices": torch.arange(
+                    0, self.max_context_len, self.page_size, device=self.device
+                ),
+            }
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -120,16 +205,105 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
     ):
         """Initialize metadata for CUDA graph capture."""
         metadata = TRTLLMMHAMetadata()
+        device = seq_lens.device
 
-        # Get sequence information
-        metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = self.decode_cuda_graph_metadata[
+                    "cache_seqlens"
+                ][:bs]
+                metadata.max_seq_len_k = seq_lens.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = self.decode_cuda_graph_metadata["cu_seqlens_q"][
+                    : bs + 1
+                ]
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = self.decode_cuda_graph_metadata[
+                    "page_table_draft_decode"
+                ][:bs, :]
+                self.decode_cuda_graph_metadata[bs] = metadata
+            else:
+                # Normal Decode
+                # Get sequence information
+                metadata.cache_seqlens_int32 = seq_lens[:bs].to(torch.int32)
+                batch_size = len(seq_lens)
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
 
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = self.max_context_len
+                # Precompute maximum sequence length
+                metadata.max_seq_len_k = seq_lens.max().item()
+                # Precompute cumulative sequence lengths
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                # Precompute page table
+                metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+                    :bs, :
+                ]
+                self.decode_cuda_graph_metadata[bs] = metadata
+        elif forward_mode.is_target_verify():
+            # Target Verify
+            # Here we only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = self.target_verify_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-        # Precompute page table
-        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][:bs, :]
-        self.decode_cuda_graph_metadata[bs] = metadata
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.target_verify_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                seq_lens.max().item() + self.speculative_num_draft_tokens
+            )
+
+            metadata.page_table = self.target_verify_metadata["page_table"][:bs, :]
+
+            self.target_verify_metadata[bs] = metadata
+        elif forward_mode.is_draft_extend():
+            metadata.cache_seqlens_int32 = self.draft_extend_metadata["cache_seqlens"][
+                :bs
+            ]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+            num_tokens_per_bs = num_tokens // bs
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                bs * num_tokens_per_bs + 1,
+                num_tokens_per_bs,
+                dtype=torch.int32,
+                device=device,
+            )
+
+            metadata.cu_seqlens_k = self.draft_extend_metadata["cu_seqlens_k"][
+                : (bs + 1)
+            ]
+            num_tokens_per_bs = num_tokens // bs
+            metadata.max_seq_len_q = num_tokens_per_bs
+            metadata.max_seq_len_k = seq_lens.max().item()
+
+            metadata.page_table = self.draft_extend_metadata["page_table"][:bs, :]
+
+            self.draft_extend_metadata[bs] = metadata
         self.forward_metadata = metadata
 
     def init_forward_metadata_replay_cuda_graph(
@@ -147,21 +321,91 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         seq_lens = seq_lens[:bs]
         seq_lens_cpu = seq_lens_cpu[:bs]
         req_pool_indices = req_pool_indices[:bs]
-        device = seq_lens.device
         metadata = None
+        if forward_mode.is_decode_or_idle():
+            if spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                metadata.max_seq_len_k = max_len + self.speculative_step_id + 1
+
+                max_seq_pages = (
+                    metadata.max_seq_len_k + self.page_size - 1
+                ) // self.page_size
+
+                metadata.cache_seqlens_int32.copy_(
+                    seq_lens + self.speculative_step_id + 1
+                )
+            else:
+                # Normal Decode
+                metadata = self.decode_cuda_graph_metadata[bs]
+                max_len = seq_lens_cpu.max().item()
+                max_seq_pages = (max_len + self.page_size - 1) // self.page_size
+                metadata.max_seq_len_k = max_len
+
+                metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][
+                    None, :
+                ],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+        elif forward_mode.is_target_verify():
+            # Here we only support topk = 1 for now.
+            metadata = self.target_verify_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(
+                (seq_lens + self.speculative_num_draft_tokens)
+            )
 
-        # Normal Decode
-        metadata = self.decode_cuda_graph_metadata[bs]
-        max_len = seq_lens_cpu.max().item()
-        max_seq_pages = (max_len + self.page_size - 1) // self.page_size
-        metadata.max_seq_len_k = self.max_context_len
-
-        metadata.cache_seqlens_int32.copy_(seq_lens)
-        page_indices = self.req_to_token[
-            req_pool_indices[:, None],
-            self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages][None, :],
-        ]
-        metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
+            metadata.max_seq_len_k = (
+                seq_lens_cpu.max().item() + self.speculative_num_draft_tokens
+            )
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages],
+            ]
+            page_indices //= self.page_size
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        elif forward_mode.is_draft_extend():
+            metadata = self.draft_extend_metadata[bs]
+            metadata.cache_seqlens_int32.copy_(seq_lens)
+
+            metadata.max_seq_len_k = seq_lens_cpu.max().item()
+            max_len = seq_lens_cpu.max().item()
+            metadata.cu_seqlens_k[1:].copy_(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32)
+            )
+            accept_length = spec_info.accept_length[:bs]
+            if spec_info.accept_length_cpu:
+                metadata.max_seq_len_q = max(spec_info.accept_length_cpu) + 1
+            else:
+                metadata.max_seq_len_q = 1
+
+            metadata.cu_seqlens_q[1:].copy_(
+                torch.cumsum(accept_length, dim=0, dtype=torch.int32)
+            )
+
+            max_seq_pages = (
+                metadata.max_seq_len_k + self.page_size - 1
+            ) // self.page_size
+            page_indices = self.req_to_token[
+                req_pool_indices[:, None],
+                self.draft_extend_metadata["strided_indices"][:max_seq_pages],
+            ]
+            metadata.page_table[:, :max_seq_pages].copy_(page_indices // self.page_size)
         self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self) -> int:
@@ -177,12 +421,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         device = seqlens_in_batch.device
 
         if forward_batch.forward_mode.is_decode_or_idle():
-            # Normal Decode
-            metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
-            metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+            if forward_batch.spec_info is not None:
+                # Draft Decode
+                # Here we only support topk = 1 for now.
+                metadata.cache_seqlens_int32 = (
+                    seqlens_in_batch + (self.speculative_step_id + 1)
+                ).to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item() + (
+                    self.speculative_step_id + 1
+                )
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(
+                        metadata.cache_seqlens_int32, dim=0, dtype=torch.int32
+                    ),
+                    (1, 0),
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+            else:
+                # Normal Decode
+                metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+                metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
+                metadata.cu_seqlens_q = torch.arange(
+                    0, batch_size + 1, dtype=torch.int32, device=device
+                )
+                metadata.cu_seqlens_k = torch.nn.functional.pad(
+                    torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+                    forward_batch.req_pool_indices, : metadata.max_seq_len_k
+                ]
+        elif forward_batch.forward_mode.is_target_verify():
+            # Only support topk = 1 for now.
+            metadata.cache_seqlens_int32 = (
+                forward_batch.seq_lens + self.speculative_num_draft_tokens
+            ).to(torch.int32)
+            metadata.max_seq_len_q = self.speculative_num_draft_tokens
+            metadata.max_seq_len_k = (
+                forward_batch.seq_lens_cpu.max().item()
+                + self.speculative_num_draft_tokens
+            )
+            metadata.cu_seqlens_q = torch.arange(
+                0,
+                batch_size * self.speculative_num_draft_tokens + 1,
+                self.speculative_num_draft_tokens,
+                dtype=torch.int32,
+                device=device,
+            )
+            metadata.cu_seqlens_k = torch.nn.functional.pad(
+                torch.cumsum(metadata.cache_seqlens_int32, dim=0, dtype=torch.int32),
+                (1, 0),
+            )
             metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
+
         else:
             metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
             metadata.max_seq_len_k = forward_batch.seq_lens_cpu.max().item()
@@ -193,7 +490,10 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
                 forward_batch.req_pool_indices, : metadata.max_seq_len_k
             ]
 
-            if any(forward_batch.extend_prefix_lens_cpu):
+            if (
+                any(forward_batch.extend_prefix_lens_cpu)
+                or forward_batch.forward_mode == ForwardMode.DRAFT_EXTEND
+            ):
                 extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
@@ -263,7 +563,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             workspace_buffer=self.workspace_buffer,
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
-            max_seq_len=self.forward_metadata.max_seq_len_k,
+            max_seq_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             window_left=layer.sliding_window_size,
@@ -318,7 +618,7 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
             block_tables=self.forward_metadata.page_table,
             seq_lens=self.forward_metadata.cache_seqlens_int32,
             max_q_len=self.forward_metadata.max_seq_len_q,
-            max_kv_len=self.forward_metadata.max_seq_len_k,
+            max_kv_len=self.max_context_len,
             bmm1_scale=bmm1_scale,
             bmm2_scale=bmm2_scale,
             batch_size=forward_batch.batch_size,
@@ -330,3 +630,65 @@ class TRTLLMHAAttnBackend(FlashInferAttnBackend):
         )
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+
+class TRTLLMHAAttnMultiStepDraftBackend(FlashInferMultiStepDraftBackend):
+    """Multi-step TRTLLM MHA attention kernel used by EAGLE."""
+
+    def __init__(
+        self, model_runner: ModelRunner, topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+        for i in range(speculative_num_steps):
+            self.attn_backends[i] = TRTLLMHAAttnBackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                kv_last_page_len_buf=self.kv_last_page_len,
+                speculative_step_id=i,
+            )
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(self, max_bs: int, max_num_tokens: int):
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i].init_cuda_graph_state(max_bs, max_num_tokens)
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        forward_batch: ForwardBatch,
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+            self.attn_backends[i].init_forward_metadata_capture_cuda_graph(
+                forward_batch.batch_size,
+                forward_batch.batch_size * self.topk,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self, forward_batch: ForwardBatch, bs: int
+    ):
+        assert forward_batch.spec_info is not None
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+
+        for i in range(self.speculative_num_steps - 1):
+
+            self.attn_backends[i].init_forward_metadata_replay_cuda_graph(
+                bs,
+                forward_batch.req_pool_indices,
+                forward_batch.seq_lens,
+                forward_batch.seq_lens_sum,
+                encoder_lens=forward_batch.encoder_lens,
+                forward_mode=ForwardMode.DECODE,
+                spec_info=forward_batch.spec_info,
+                seq_lens_cpu=forward_batch.seq_lens_cpu,
+            )
sglang/srt/layers/attention/trtllm_mla_backend.py
@@ -11,7 +11,10 @@ from typing import TYPE_CHECKING, Optional, Union
 import torch
 import triton
 
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import (
+    FlashInferMLAAttnBackend,
+    FlashInferMLAMultiStepDraftBackend,
+)
 from sglang.srt.layers.attention.utils import (
     TRITON_PAD_NUM_PAGE_PER_BLOCK,
     create_flashmla_kv_indices_triton,
@@ -39,6 +42,8 @@ DEFAULT_WORKSPACE_SIZE_MB = 128 # Memory workspace size in MB
 # compute the LCM with other padding constraints.
 TRTLLM_BLOCK_CONSTRAINT = 128
 
+global_zero_init_workspace_buffer = None
+
 
 
 @dataclass
@@ -83,13 +88,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
         # Workspace allocation
         self.workspace_size = DEFAULT_WORKSPACE_SIZE_MB * 1024 * 1024
-        self.workspace_buffer = torch.empty(
-            self.workspace_size, dtype=torch.int8, device=self.device
-        )
+        global global_zero_init_workspace_buffer
+        if global_zero_init_workspace_buffer is None:
+            global_zero_init_workspace_buffer = torch.zeros(
+                self.workspace_size,
+                dtype=torch.uint8,
+                device=model_runner.device,
+            )
+        self.workspace_buffer = global_zero_init_workspace_buffer
 
         # CUDA graph state
         self.decode_cuda_graph_metadata = {}
-        self.cuda_graph_kv_indices = None
+        self.decode_cuda_graph_kv_indices = None
         self.forward_metadata: Union[TRTLLMMLADecodeMetadata, None] = None
 
     def _calc_padded_blocks(self, max_seq_len: int) -> int:
@@ -160,15 +170,18 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         kv_indices_buf: Optional[torch.Tensor] = None,
     ):
         """Initialize CUDA graph state for TRTLLM MLA."""
+
         max_blocks_per_seq = self._calc_padded_blocks(self.max_context_len)
 
-        self.cuda_graph_kv_indices = torch.full(
+        self.decode_cuda_graph_kv_indices = torch.full(
             (max_bs, max_blocks_per_seq), -1, dtype=torch.int32, device=self.device
         )
-        self.cuda_graph_workspace = torch.empty(
+        self.decode_cuda_graph_workspace = torch.empty(
             self.workspace_size, dtype=torch.int8, device=self.device
         )
 
+        super().init_cuda_graph_state(max_bs, max_num_tokens, kv_indices_buf)
+
     def init_forward_metadata_capture_cuda_graph(
         self,
         bs: int,
@@ -180,8 +193,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         spec_info: Optional[SpecInfo],
     ):
         """Initialize metadata for CUDA graph capture."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_capture_cuda_graph(
                 bs,
                 num_tokens,
@@ -192,9 +206,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
                 spec_info,
             )
 
-        # Custom fast-path for decode/idle without speculative execution.
+        # Custom fast-path for decode/idle.
         max_seqlen_pad = self._calc_padded_blocks(seq_lens.max().item())
-        block_kv_indices = self.cuda_graph_kv_indices[:bs, :max_seqlen_pad]
+        block_kv_indices = self.decode_cuda_graph_kv_indices[:bs, :max_seqlen_pad]
 
         create_flashmla_kv_indices_triton[(bs,)](
             self.req_to_token,
@@ -208,7 +222,9 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
             PAGED_SIZE=self.page_size,
         )
 
-        metadata = TRTLLMMLADecodeMetadata(self.cuda_graph_workspace, block_kv_indices)
+        metadata = TRTLLMMLADecodeMetadata(
+            self.decode_cuda_graph_workspace, block_kv_indices
+        )
         self.decode_cuda_graph_metadata[bs] = metadata
         self.forward_metadata = metadata
 
@@ -224,8 +240,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         """Replay CUDA graph with new inputs."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (forward_mode.is_decode_or_idle() and spec_info is None):
+        # Delegate to parent for non-decode modes.
+        if not forward_mode.is_decode_or_idle():
             return super().init_forward_metadata_replay_cuda_graph(
                 bs,
                 req_pool_indices,
@@ -258,11 +274,8 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize the metadata for a forward pass."""
-        # Delegate to parent for non-decode modes or when speculative execution is used.
-        if not (
-            forward_batch.forward_mode.is_decode_or_idle()
-            and forward_batch.spec_info is None
-        ):
+        # Delegate to parent for non-decode modes.
+        if not forward_batch.forward_mode.is_decode_or_idle():
             return super().init_forward_metadata(forward_batch)
 
         bs = forward_batch.batch_size
@@ -467,3 +480,20 @@ class TRTLLMMLABackend(FlashInferMLAAttnBackend):
         output = raw_out_v.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
         return output
+
+
+class TRTLLMMLAMultiStepDraftBackend(FlashInferMLAMultiStepDraftBackend):
+    """Multi-step draft backend for TRT-LLM MLA used by EAGLE."""
+
+    def __init__(
+        self, model_runner: "ModelRunner", topk: int, speculative_num_steps: int
+    ):
+        super().__init__(model_runner, topk, speculative_num_steps)
+
+        for i in range(self.speculative_num_steps):
+            self.attn_backends[i] = TRTLLMMLABackend(
+                model_runner,
+                skip_prefill=True,
+                kv_indptr_buf=self.kv_indptr[i],
+                q_indptr_decode_buf=self.q_indptr_decode,
+            )
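
Note: both new multi-step draft backends in this diff are instantiated the same way, one single-step backend per speculative step. A minimal sketch, based only on the __init__ signatures visible in the hunks above; the helper function below is hypothetical and not part of the wheel, and it assumes an already-initialized ModelRunner.

# Hypothetical usage sketch (not part of this diff).
from sglang.srt.layers.attention.trtllm_mha_backend import TRTLLMHAAttnMultiStepDraftBackend
from sglang.srt.layers.attention.trtllm_mla_backend import TRTLLMMLAMultiStepDraftBackend


def build_trtllm_draft_backend(model_runner, use_mla: bool, topk: int = 1, num_steps: int = 3):
    # Each wrapper allocates one single-step TRTLLM backend per speculative step
    # (skip_prefill=True in the loops shown above), reusing the parent's kv_indptr buffers.
    backend_cls = (
        TRTLLMMLAMultiStepDraftBackend if use_mla else TRTLLMHAAttnMultiStepDraftBackend
    )
    return backend_cls(model_runner, topk, num_steps)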