sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in those registries.
Files changed (185)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +26 -4
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +676 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +49 -8
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  34. sglang/srt/distributed/parallel_state.py +42 -8
  35. sglang/srt/entrypoints/engine.py +55 -5
  36. sglang/srt/entrypoints/http_server.py +78 -13
  37. sglang/srt/entrypoints/verl_engine.py +2 -0
  38. sglang/srt/function_call_parser.py +133 -55
  39. sglang/srt/hf_transformers_utils.py +28 -3
  40. sglang/srt/layers/activation.py +4 -2
  41. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  42. sglang/srt/layers/attention/flashattention_backend.py +434 -0
  43. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  44. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  45. sglang/srt/layers/attention/triton_backend.py +171 -38
  46. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  47. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  48. sglang/srt/layers/attention/utils.py +53 -0
  49. sglang/srt/layers/attention/vision.py +9 -28
  50. sglang/srt/layers/dp_attention.py +41 -19
  51. sglang/srt/layers/layernorm.py +24 -2
  52. sglang/srt/layers/linear.py +17 -5
  53. sglang/srt/layers/logits_processor.py +25 -7
  54. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  55. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  56. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  57. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  62. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  63. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  64. sglang/srt/layers/moe/topk.py +60 -20
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +80 -53
  67. sglang/srt/layers/quantization/awq.py +200 -0
  68. sglang/srt/layers/quantization/base_config.py +5 -0
  69. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  70. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  72. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  75. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  76. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  77. sglang/srt/layers/quantization/fp8.py +76 -34
  78. sglang/srt/layers/quantization/fp8_kernel.py +25 -8
  79. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  80. sglang/srt/layers/quantization/gptq.py +36 -19
  81. sglang/srt/layers/quantization/kv_cache.py +98 -0
  82. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  83. sglang/srt/layers/quantization/utils.py +153 -0
  84. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  85. sglang/srt/layers/rotary_embedding.py +78 -87
  86. sglang/srt/layers/sampler.py +1 -1
  87. sglang/srt/lora/backend/base_backend.py +4 -4
  88. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  89. sglang/srt/lora/backend/triton_backend.py +5 -8
  90. sglang/srt/lora/layers.py +87 -33
  91. sglang/srt/lora/lora.py +2 -22
  92. sglang/srt/lora/lora_manager.py +67 -30
  93. sglang/srt/lora/mem_pool.py +117 -52
  94. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  95. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  96. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  97. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  98. sglang/srt/lora/utils.py +18 -1
  99. sglang/srt/managers/cache_controller.py +2 -5
  100. sglang/srt/managers/data_parallel_controller.py +30 -8
  101. sglang/srt/managers/expert_distribution.py +81 -0
  102. sglang/srt/managers/io_struct.py +43 -5
  103. sglang/srt/managers/mm_utils.py +373 -0
  104. sglang/srt/managers/multimodal_processor.py +68 -0
  105. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  106. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  107. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  108. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  109. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  110. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  111. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  112. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  113. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  114. sglang/srt/managers/schedule_batch.py +134 -30
  115. sglang/srt/managers/scheduler.py +290 -31
  116. sglang/srt/managers/session_controller.py +1 -1
  117. sglang/srt/managers/tokenizer_manager.py +59 -24
  118. sglang/srt/managers/tp_worker.py +4 -1
  119. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  120. sglang/srt/managers/utils.py +6 -1
  121. sglang/srt/mem_cache/hiradix_cache.py +18 -7
  122. sglang/srt/mem_cache/memory_pool.py +255 -98
  123. sglang/srt/mem_cache/paged_allocator.py +2 -2
  124. sglang/srt/mem_cache/radix_cache.py +4 -4
  125. sglang/srt/model_executor/cuda_graph_runner.py +36 -21
  126. sglang/srt/model_executor/forward_batch_info.py +68 -11
  127. sglang/srt/model_executor/model_runner.py +75 -8
  128. sglang/srt/model_loader/loader.py +171 -3
  129. sglang/srt/model_loader/weight_utils.py +51 -3
  130. sglang/srt/models/clip.py +563 -0
  131. sglang/srt/models/deepseek_janus_pro.py +31 -88
  132. sglang/srt/models/deepseek_nextn.py +22 -10
  133. sglang/srt/models/deepseek_v2.py +329 -73
  134. sglang/srt/models/deepseek_vl2.py +358 -0
  135. sglang/srt/models/gemma3_causal.py +694 -0
  136. sglang/srt/models/gemma3_mm.py +468 -0
  137. sglang/srt/models/llama.py +47 -7
  138. sglang/srt/models/llama_eagle.py +1 -0
  139. sglang/srt/models/llama_eagle3.py +196 -0
  140. sglang/srt/models/llava.py +3 -3
  141. sglang/srt/models/llavavid.py +3 -3
  142. sglang/srt/models/minicpmo.py +1995 -0
  143. sglang/srt/models/minicpmv.py +62 -137
  144. sglang/srt/models/mllama.py +4 -4
  145. sglang/srt/models/phi3_small.py +1 -1
  146. sglang/srt/models/qwen2.py +3 -0
  147. sglang/srt/models/qwen2_5_vl.py +68 -146
  148. sglang/srt/models/qwen2_classification.py +75 -0
  149. sglang/srt/models/qwen2_moe.py +9 -1
  150. sglang/srt/models/qwen2_vl.py +25 -63
  151. sglang/srt/openai_api/adapter.py +201 -104
  152. sglang/srt/openai_api/protocol.py +33 -7
  153. sglang/srt/patch_torch.py +71 -0
  154. sglang/srt/sampling/sampling_batch_info.py +1 -1
  155. sglang/srt/sampling/sampling_params.py +6 -6
  156. sglang/srt/server_args.py +114 -14
  157. sglang/srt/speculative/build_eagle_tree.py +7 -347
  158. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  159. sglang/srt/speculative/eagle_utils.py +208 -252
  160. sglang/srt/speculative/eagle_worker.py +140 -54
  161. sglang/srt/speculative/spec_info.py +6 -1
  162. sglang/srt/torch_memory_saver_adapter.py +22 -0
  163. sglang/srt/utils.py +215 -21
  164. sglang/test/__init__.py +0 -0
  165. sglang/test/attention/__init__.py +0 -0
  166. sglang/test/attention/test_flashattn_backend.py +312 -0
  167. sglang/test/runners.py +29 -2
  168. sglang/test/test_activation.py +2 -1
  169. sglang/test/test_block_fp8.py +5 -4
  170. sglang/test/test_block_fp8_ep.py +2 -1
  171. sglang/test/test_dynamic_grad_mode.py +58 -0
  172. sglang/test/test_layernorm.py +3 -2
  173. sglang/test/test_utils.py +56 -5
  174. sglang/utils.py +31 -0
  175. sglang/version.py +1 -1
  176. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
  177. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
  178. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
  179. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  180. sglang/srt/managers/image_processor.py +0 -55
  181. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  182. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  183. sglang/srt/managers/multi_modality_padding.py +0 -134
  184. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
  185. {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/hf_transformers_utils.py
+++ b/sglang/srt/hf_transformers_utils.py
@@ -33,16 +33,18 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_N
 from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
+    DeepseekVL2Config,
     ExaoneConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
 }
 
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent
 
+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
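The new branch above lets get_tokenizer accept a remote URL: the connector downloads only the small tokenizer/config files into a local directory and loading proceeds from there. A minimal usage sketch (illustrative, not part of the diff), assuming is_remote_url recognizes the s3:// scheme of the new S3 connector and that the bucket path below is a hypothetical Hugging Face-style model layout:

import os

from sglang.srt.hf_transformers_utils import get_tokenizer

# Weight files (*.pt, *.safetensors, *.bin) are excluded from the pull;
# only tokenizer and config files are fetched into the connector's local dir.
tokenizer = get_tokenizer("s3://my-model-bucket/llama-3-8b-instruct")
print(type(tokenizer).__name__)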
@@ -207,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-        tokenizer_revision=tokenizer_revision,
+        revision=revision,
         **kwargs,
     )
 
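A short usage sketch of the patched helper (illustrative, not part of the diff); the model id is only an example of a checkpoint whose config.model_type is "qwen2_vl", and passing an explicit size bypasses the default injected above:

from sglang.srt.hf_transformers_utils import get_processor

processor = get_processor(
    "Qwen/Qwen2-VL-7B-Instruct",  # example Qwen2-VL checkpoint
    trust_remote_code=True,
    # An explicit size takes precedence over the default injected by the patch.
    size={"shortest_edge": 3136, "longest_edge": 1003520},
)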
--- a/sglang/srt/layers/activation.py
+++ b/sglang/srt/layers/activation.py
@@ -23,7 +23,9 @@ import torch.nn.functional as F
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn
 
 
-if not is_cuda_available():
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
--- a/sglang/srt/layers/attention/base_attn_backend.py
+++ b/sglang/srt/layers/attention/base_attn_backend.py
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replying a cuda graph."""
+        """Init the metadata for a forward pass for replaying a cuda graph."""
         raise NotImplementedError()
 
     def get_cuda_graph_seq_len_fill_value(self):
--- /dev/null
+++ b/sglang/srt/layers/attention/flashattention_backend.py
@@ -0,0 +1,434 @@
+from __future__ import annotations
+
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+"""
+Support different attention backends.
+Now there are three backends: FlashInfer, Triton and FlashAttention.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.configs.model_config import AttentionArch
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from flash_attn_interface import flash_attn_with_kvcache
+
+
+@dataclass
+class FlashAttentionMetadata:
+    """Metadata for decode operations to avoid redundant computations."""
+
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
+    max_seq_len_k: int = 0
+    window_size: tuple = (-1, -1)
+    page_table: torch.Tensor = None
+    cache_seqlens_int32: torch.Tensor = None
+
+
+class FlashAttentionBackend(AttentionBackend):
+    """FlashAttention backend implementation."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+    ):
+        super().__init__()
+
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+
+        # Initialize metadata
+        self.forward_metadata: FlashAttentionMetadata = None
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.decode_cuda_graph_metadata = {}
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize forward metadata to cache repetitive calculations."""
+        # Create metadata based on forward mode
+        metadata = FlashAttentionMetadata()
+
+        # Get sequence information
+        seqlens_in_batch = forward_batch.seq_lens
+        # Precompute int32 version of sequence lengths
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        batch_size = len(seqlens_in_batch)
+        device = seqlens_in_batch.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Precompute page table
+        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, : metadata.max_seq_len_k
+        ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
+        if forward_batch.forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            # Precompute cumulative sequence lengths
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
+            else:
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+                metadata.max_seq_len_q = metadata.max_seq_len_k
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+
+        page_table = metadata.page_table
+
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention using precomputed metadata."""
+        # Save KV cache if needed
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        # Initialize fixed size tensors for decode operations
+        self.decode_cuda_graph_metadata = {
+            # Page table for token mapping (batch_size, max_context_len)
+            "page_table": torch.zeros(
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        metadata = FlashAttentionMetadata()
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        batch_size = len(seq_lens)
+        device = seq_lens.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seq_lens.max().item()
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+            req_pool_indices, :
+        ]
+        if forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            raise ValueError("Do not support Prefill Mode cuda graph")
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        # """Initialize forward metadata for replaying CUDA graph."""
+        metadata = self.decode_cuda_graph_metadata[bs]
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+
+        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
+        page_indices = self.req_to_token[
+            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+        ]
+        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        metadata.page_table[:, max_seq_pages:].fill_(0)
+        self.forward_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 0
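To make the metadata precomputation in init_forward_metadata concrete, here is a small self-contained sketch (plain PyTorch with made-up values; the names mirror the backend but nothing here is part of the package) showing how cache_seqlens_int32, cu_seqlens_k, and a page-granular page_table are derived from per-request sequence lengths:

import torch

# Illustrative values: 3 requests with total KV lengths 5, 9, 3 and page size 4.
seq_lens = torch.tensor([5, 9, 3])
page_size = 4

cache_seqlens_int32 = seq_lens.to(torch.int32)

# Cumulative lengths with a leading 0, as used for cu_seqlens_k.
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
print(cu_seqlens_k)  # tensor([ 0,  5, 14, 17], dtype=torch.int32)

# A toy token-level table (requests x max_seq_len_k), then downsampled to page
# granularity by taking every page_size-th slot and dividing by page_size,
# mirroring the strided-indices step above.
max_seq_len_k = int(seq_lens.max().item())
req_to_token = torch.arange(3 * 12, dtype=torch.int32).reshape(3, 12)
page_table = req_to_token[:, :max_seq_len_k]
if page_size > 1:
    strided_indices = torch.arange(0, page_table.shape[1], page_size)
    page_table = page_table[:, strided_indices] // page_size
print(page_table.shape)  # torch.Size([3, 3]) -> ceil(9 / 4) = 3 pages per request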
--- a/sglang/srt/layers/attention/flashinfer_backend.py
+++ b/sglang/srt/layers/attention/flashinfer_backend.py
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
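The torch.zeros to torch.empty change above drops an unnecessary zero-fill: the kv_indices scratch buffer is fully written before it is read. A tiny sketch of the same idea with illustrative shapes (not sglang code):

import torch

# torch.empty only allocates; it does not write zeros. That is safe whenever
# every element is overwritten before being read, as with a scratch buffer.
buf = torch.empty((4, 1024), dtype=torch.int32)
buf.fill_(7)                 # producer fills every slot...
assert int(buf.min()) == 7   # ...so uninitialized memory is never observed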