sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff covers publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between those versions as they appear in their public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/hf_transformers_utils.py
@@ -33,16 +33,18 @@ from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_N
 from sglang.srt.configs import (
     ChatGLMConfig,
     DbrxConfig,
+    DeepseekVL2Config,
     ExaoneConfig,
     MultiModalityConfig,
-    Qwen2_5_VLConfig,
 )
+from sglang.srt.connector import create_remote_connector
+from sglang.srt.utils import is_remote_url
 
 _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
     ChatGLMConfig.model_type: ChatGLMConfig,
     DbrxConfig.model_type: DbrxConfig,
     ExaoneConfig.model_type: ExaoneConfig,
-    Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
+    DeepseekVL2Config.model_type: DeepseekVL2Config,
     MultiModalityConfig.model_type: MultiModalityConfig,
 }
 
@@ -155,6 +157,14 @@ def get_tokenizer(
         kwargs["gguf_file"] = tokenizer_name
         tokenizer_name = Path(tokenizer_name).parent
 
+    if is_remote_url(tokenizer_name):
+        # BaseConnector implements __del__() to clean up the local dir.
+        # Since config files need to exist all the time, so we DO NOT use
+        # with statement to avoid closing the client.
+        client = create_remote_connector(tokenizer_name)
+        client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
+        tokenizer_name = client.get_local_dir()
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
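
Note: the hunk above adds a remote-model path to get_tokenizer(). The minimal sketch below is not part of the package diff; it only illustrates how the new connector helpers are meant to be combined, assuming a hypothetical S3 location. The helper names come from the sglang/srt/connector and sglang/srt/utils modules added in this release.

from sglang.srt.connector import create_remote_connector
from sglang.srt.utils import is_remote_url

model_source = "s3://my-bucket/my-model"  # hypothetical remote location
if is_remote_url(model_source):
    # Pull only the small files (tokenizer/config); skip the weight files.
    client = create_remote_connector(model_source)
    client.pull_files(ignore_pattern=["*.pt", "*.safetensors", "*.bin"])
    model_source = client.get_local_dir()  # local dir usable by AutoTokenizer
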
@@ -207,11 +217,26 @@ def get_processor(
     tokenizer_revision: Optional[str] = None,
     **kwargs,
 ):
+    # pop 'revision' from kwargs if present.
+    revision = kwargs.pop("revision", tokenizer_revision)
+
+    config = AutoConfig.from_pretrained(
+        tokenizer_name,
+        trust_remote_code=trust_remote_code,
+        revision=revision,
+        **kwargs,
+    )
+
+    # fix: for Qwen2-VL model, inject default 'size' if not provided.
+    if config.model_type in {"qwen2_vl"}:
+        if "size" not in kwargs:
+            kwargs["size"] = {"shortest_edge": 3136, "longest_edge": 1003520}
+
     processor = AutoProcessor.from_pretrained(
         tokenizer_name,
         *args,
         trust_remote_code=trust_remote_code,
-        tokenizer_revision=tokenizer_revision,
+        revision=revision,
         **kwargs,
     )
 
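
Note: with the change above, get_processor() resolves a single revision (falling back to the old tokenizer_revision argument), forwards it to both AutoConfig and AutoProcessor, and injects a default image size for Qwen2-VL when none is supplied. A hedged usage sketch, not part of the diff, with a hypothetical model id:

from sglang.srt.hf_transformers_utils import get_processor

processor = get_processor(
    "Qwen/Qwen2-VL-7B-Instruct",  # hypothetical model id for illustration
    trust_remote_code=True,
    revision="main",  # now forwarded as `revision`, not `tokenizer_revision`
)
# For qwen2_vl configs, kwargs["size"] defaults to
# {"shortest_edge": 3136, "longest_edge": 1003520} unless the caller sets it.
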
sglang/srt/layers/activation.py
@@ -23,7 +23,9 @@ import torch.nn.functional as F
 
 from sglang.srt.utils import is_cuda_available
 
-if is_cuda_available():
+_is_cuda = is_cuda_available()
+
+if _is_cuda:
     from sgl_kernel import gelu_and_mul, gelu_tanh_and_mul, silu_and_mul
 
 from sglang.srt.custom_op import CustomOp
@@ -165,7 +167,7 @@ def get_act_fn(
     return act_fn
 
 
-if not is_cuda_available():
+if not _is_cuda:
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
sglang/srt/layers/attention/base_attn_backend.py
@@ -47,7 +47,7 @@ class AttentionBackend(ABC):
         spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
         seq_lens_cpu: Optional[torch.Tensor],
     ):
-        """Init the metadata for a forward pass for replying a cuda graph."""
+        """Init the metadata for a forward pass for replaying a cuda graph."""
        raise NotImplementedError()
 
     def get_cuda_graph_seq_len_fill_value(self):
sglang/srt/layers/attention/flashattention_backend.py
@@ -0,0 +1,295 @@
+from __future__ import annotations
+
+from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+
+"""
+Support different attention backends.
+Now there are three backends: FlashInfer, Triton and FlashAttention.
+Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+from flash_attn_interface import flash_attn_with_kvcache
+
+
+@dataclass
+class FlashAttentionMetadata:
+    """Metadata for decode operations to avoid redundant computations."""
+
+    cu_seqlens_q: torch.Tensor = None
+    cu_seqlens_k: torch.Tensor = None
+    max_seq_len_k: int = 0
+    window_size: tuple = (-1, -1)
+    page_table: torch.Tensor = None
+    cache_seqlens_int32: torch.Tensor = None
+    max_seq_len_q: int = 0
+
+
+class FlashAttentionBackend(AttentionBackend):
+    """FlashAttention backend implementation."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+    ):
+        super().__init__()
+
+        assert not (
+            model_runner.sliding_window_size is not None
+            and model_runner.model_config.is_encoder_decoder
+        ), "Sliding window and cross attention are not supported together"
+
+        # Initialize metadata
+        self.forward_metadata: FlashAttentionMetadata = None
+        self.max_context_len = model_runner.model_config.context_len
+        self.device = model_runner.device
+        self.decode_cuda_graph_metadata = {}
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Initialize forward metadata to cache repetitive calculations."""
+        # Create metadata based on forward mode
+        metadata = FlashAttentionMetadata()
+
+        extend_seq_lens = forward_batch.extend_seq_lens
+        # Get sequence information
+        seqlens_in_batch = forward_batch.seq_lens
+        # Precompute int32 version of sequence lengths
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        batch_size = len(seqlens_in_batch)
+        device = seqlens_in_batch.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Precompute page table
+        metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
+            forward_batch.req_pool_indices, : metadata.max_seq_len_k
+        ]
+        if forward_batch.forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
+            # Precompute cumulative sequence lengths
+            if not extend_no_prefix:
+                metadata.cu_seqlens_q = torch.nn.functional.pad(
+                    torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
+                )
+            else:
+                metadata.cu_seqlens_q = metadata.cu_seqlens_k
+            metadata.max_seq_len_q = seqlens_in_batch.max().item()
+        self.forward_metadata = metadata
+
+    def forward_extend(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ):
+        cache_loc = (
+            forward_batch.out_cache_loc
+            if not layer.is_cross_attention
+            else forward_batch.encoder_out_cache_loc
+        )
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                )
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # # Use Flash Attention for prefill
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+        o = flash_attn_with_kvcache(
+            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=metadata.max_seq_len_q,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
+    ) -> torch.Tensor:
+        """Forward pass with FlashAttention using precomputed metadata."""
+        # Save KV cache if needed
+        if k is not None and v is not None and save_kv_cache:
+            cache_loc = (
+                forward_batch.out_cache_loc
+                if not layer.is_cross_attention
+                else forward_batch.encoder_out_cache_loc
+            )
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+            )
+
+        # Get KV cache
+        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+        key_cache, value_cache = kv_cache[0], kv_cache[1]
+
+        # Use precomputed metadata
+        metadata = self.forward_metadata
+
+        # Pre-reshape query tensor
+        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+        # Calculate window size (can be moved to metadata if layer properties don't change)
+        # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
+        # here is two side inclusive
+        window_size = (
+            (layer.sliding_window_size, 0)
+            if layer.sliding_window_size is not None
+            else (-1, -1)
+        )
+        # Run attention with precomputed values
+        o = flash_attn_with_kvcache(
+            q=q_reshaped,
+            k_cache=key_cache.unsqueeze(1),
+            v_cache=value_cache.unsqueeze(1),
+            page_table=metadata.page_table,
+            cache_seqlens=metadata.cache_seqlens_int32,
+            cu_seqlens_q=metadata.cu_seqlens_q,
+            cu_seqlens_k_new=metadata.cu_seqlens_k,
+            max_seqlen_q=1,
+            softmax_scale=layer.scaling,
+            causal=True,
+            window_size=window_size,
+            softcap=layer.logit_cap,
+            k_descale=layer.k_scale,
+            v_descale=layer.v_scale,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+
+    def init_cuda_graph_state(self, max_bs: int):
+        """Initialize CUDA graph state for the attention backend.
+
+        Args:
+            max_bs (int): Maximum batch size to support in CUDA graphs
+
+        This creates fixed-size tensors that will be reused during CUDA graph replay
+        to avoid memory allocations.
+        """
+        # Initialize fixed size tensors for decode operations
+        self.decode_cuda_graph_metadata = {
+            # Page table for token mapping (batch_size, max_context_len)
+            "page_table": torch.zeros(
+                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+            ),
+        }
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+    ):
+        """Initialize forward metadata for capturing CUDA graph."""
+        metadata = FlashAttentionMetadata()
+        # Get sequence information
+        metadata.cache_seqlens_int32 = seq_lens.to(torch.int32)
+        batch_size = len(seq_lens)
+        device = seq_lens.device
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seq_lens.max().item()
+        # Precompute page table
+        metadata.page_table = self.decode_cuda_graph_metadata["page_table"][
+            req_pool_indices, :
+        ]
+        if forward_mode == ForwardMode.DECODE:
+            # Precompute cumulative sequence lengths
+            metadata.cu_seqlens_q = torch.arange(
+                0, batch_size + 1, dtype=torch.int32, device=device
+            )
+        else:
+            raise ValueError("Do not support Prefill Mode cuda graph")
+        self.decode_cuda_graph_metadata[bs] = metadata
+        self.forward_metadata = metadata
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[Union[EagleDraftInput, EagleVerifyInput]],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+        # """Initialize forward metadata for replaying CUDA graph."""
+        seqlens_in_batch = seq_lens[:bs]
+        metadata = self.decode_cuda_graph_metadata[bs]
+        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+        metadata.cu_seqlens_k = torch.nn.functional.pad(
+            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+        )
+        # Precompute maximum sequence length
+        metadata.max_seq_len_k = seqlens_in_batch.max().item()
+        # Only zero out the part out of max_len_k
+        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
+        # Then do the copy
+        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
+            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
+        )
+        self.forward_decode_metadata = metadata
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        """Get the fill value for sequence length in CUDA graph."""
+        return 0
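
Note: the init_forward_metadata code above boils down to two precomputations that feed flash_attn_with_kvcache: padded cumulative key lengths and a per-request page table. The standalone sketch below is not part of the diff and uses hypothetical toy tensors; it only mirrors those two computations.

import torch

seq_lens = torch.tensor([5, 3, 7])  # toy per-request sequence lengths
cache_seqlens_int32 = seq_lens.to(torch.int32)
cu_seqlens_k = torch.nn.functional.pad(
    torch.cumsum(seq_lens, dim=0, dtype=torch.int32), (1, 0)
)
# cu_seqlens_k == tensor([0, 5, 8, 15], dtype=torch.int32)
max_seq_len_k = seq_lens.max().item()  # 7

# req_to_token maps (request slot, position) -> KV cache token index;
# this toy pool has 4 request slots with 16 token positions each.
req_to_token = torch.arange(4 * 16, dtype=torch.int32).reshape(4, 16)
req_pool_indices = torch.tensor([2, 0, 3])  # slots used by this batch
page_table = req_to_token[req_pool_indices, :max_seq_len_k]
# page_table has shape (batch_size, max_seq_len_k) and is what the backend
# passes to flash_attn_with_kvcache as the per-request token index table.
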
sglang/srt/layers/attention/flashinfer_backend.py
@@ -1008,7 +1008,7 @@ class FlashInferMultiStepDraftBackend:
         global_override_indptr_cpu = None
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
-        kv_indices = torch.zeros(
+        kv_indices = torch.empty(
             (
                 self.speculative_num_steps,
                 forward_batch.batch_size * self.topk * self.max_context_len,
sglang/srt/layers/attention/flashmla_backend.py
@@ -0,0 +1,284 @@
+from __future__ import annotations
+
+"""
+Support attention backend for FlashMLA.
+
+#TODO
+Enable speculative sampling in FlashMLA
+"""
+
+from dataclasses import dataclass
+from typing import TYPE_CHECKING, Optional, Union
+
+import torch
+import triton
+from flash_mla import flash_mla_with_kvcache, get_mla_metadata
+
+from sglang.global_config import global_config
+from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
+from sglang.srt.layers.attention.utils import create_flashmla_kv_indices_triton
+from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpecInfo
+
+
+# FlashMLA only supports pagesize=64
+PAGE_SIZE = 64
+# TODO The current setup is hard-coded and will be changed after integrating with MTP.
+Q_LEN = 1
+
+
+@dataclass
+class FlashMLADecodeMetadata:
+    flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None
+    num_splits: Optional[torch.Tensor] = None
+    block_kv_indices: Optional[torch.Tensor] = None
+
+    def __init__(
+        self,
+        flashmla_metadata: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
+        num_splits: Optional[torch.Tensor] = None,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        self.flashmla_metadata = flashmla_metadata
+        self.num_splits = num_splits
+        self.block_kv_indices = block_kv_indices
+
+
+class FlashMLABackend(FlashInferMLAAttnBackend):
+    """Flashinfer attention kernels."""
+
+    def __init__(
+        self,
+        model_runner: ModelRunner,
+        skip_prefill: bool = False,
+        kv_indptr_buf: Optional[torch.Tensor] = None,
+        kv_last_page_len_buf: Optional[torch.Tensor] = None,
+    ):
+        super().__init__(
+            model_runner, skip_prefill, kv_indptr_buf, kv_last_page_len_buf
+        )
+
+        self.num_q_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.num_kv_heads = model_runner.model_config.get_num_kv_heads(
+            get_attention_tp_size()
+        )
+        self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.num_local_heads = (
+            model_runner.model_config.num_attention_heads // get_attention_tp_size()
+        )
+        self.forward_metadata: Union[FlashMLADecodeMetadata] = None
+        self.kv_lora_rank = model_runner.model_config.kv_lora_rank
+        self.qk_nope_head_dim = model_runner.model_config.qk_nope_head_dim
+        self.qk_rope_head_dim = model_runner.model_config.qk_rope_head_dim
+        self.v_head_dim = model_runner.model_config.v_head_dim
+        self.scaling = model_runner.model_config.scaling
+        self.data_type = model_runner.kv_cache_dtype
+        self.q_data_type = model_runner.dtype
+        self.kv_cache_dim = self.kv_lora_rank + self.qk_rope_head_dim
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+
+        bs = forward_batch.batch_size
+        spec_info = forward_batch.spec_info
+        if forward_batch.forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(
+                    forward_batch.decode_seq_lens_cpu.max().item(), PAGE_SIZE
+                )
+                block_kv_indices = torch.full(
+                    (bs, max_seqlen_pad),
+                    -1,
+                    dtype=torch.int32,
+                    device=forward_batch.seq_lens.device,
+                )
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    forward_batch.req_pool_indices,
+                    forward_batch.seq_lens,
+                    None,
+                    block_kv_indices,
+                    self.req_to_token.stride(0),
+                    max_seqlen_pad,
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    forward_batch.seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    mla_metadata,
+                    num_splits,
+                    block_kv_indices,
+                )
+            else:
+                super().init_forward_metadata(forward_batch)
+        else:
+            super().init_forward_metadata(forward_batch)
+
+    def init_cuda_graph_state(
+        self,
+        max_bs: int,
+        block_kv_indices: Optional[torch.Tensor] = None,
+    ):
+        if block_kv_indices is None:
+            cuda_graph_kv_indices = torch.full(
+                (max_bs, (self.max_context_len + PAGE_SIZE) // PAGE_SIZE),
+                1,
+                dtype=torch.int32,
+                device="cuda",
+            )
+        else:
+            cuda_graph_kv_indices = block_kv_indices
+
+        self.cuda_graph_mla_metadata, self.cuda_graph_num_splits = get_mla_metadata(
+            torch.ones(max_bs, dtype=torch.int32, device=cuda_graph_kv_indices.device),
+            Q_LEN * self.num_q_heads // self.num_kv_heads,
+            self.num_kv_heads,
+        )
+        self.cuda_graph_kv_indices = cuda_graph_kv_indices
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        num_tokens: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+    ):
+        if forward_mode.is_decode_or_idle():
+            if spec_info is None:
+                max_seqlen_pad = triton.cdiv(seq_lens.max().item(), PAGE_SIZE)
+
+                create_flashmla_kv_indices_triton[(bs,)](
+                    self.req_to_token,
+                    req_pool_indices,
+                    seq_lens,
+                    None,
+                    self.cuda_graph_kv_indices,
+                    self.req_to_token.stride(0),
+                    self.cuda_graph_kv_indices.stride(0),
+                )
+                mla_metadata, num_splits = get_mla_metadata(
+                    seq_lens.to(torch.int32),
+                    Q_LEN * self.num_q_heads // self.num_kv_heads,
+                    self.num_kv_heads,
+                )
+                self.cuda_graph_mla_metadata.copy_(mla_metadata)
+                self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+                self.forward_metadata = FlashMLADecodeMetadata(
+                    self.cuda_graph_mla_metadata,
+                    self.cuda_graph_num_splits[: bs + 1],
+                    self.cuda_graph_kv_indices[:bs, :max_seqlen_pad],
+                )
+
+        else:
+            super().init_forward_metadata_capture_cuda_graph(
+                bs,
+                num_tokens,
+                req_pool_indices,
+                seq_lens,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+            )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor],
+        forward_mode: ForwardMode,
+        spec_info: Optional[SpecInfo],
+        seq_lens_cpu: Optional[torch.Tensor],
+    ):
+
+        if forward_mode.is_decode_or_idle():
+            assert seq_lens_cpu is not None
+            seq_lens = seq_lens[:bs]
+            seq_lens_cpu = seq_lens_cpu[:bs]
+            max_seqlen_pad = triton.cdiv(seq_lens_cpu.max().item(), PAGE_SIZE)
+            create_flashmla_kv_indices_triton[(bs,)](
+                self.req_to_token,
+                req_pool_indices[:bs],
+                seq_lens,
+                None,
+                self.cuda_graph_kv_indices,
+                self.req_to_token.stride(0),
+                self.cuda_graph_kv_indices.stride(0),
+            )
+            mla_metadata, num_splits = get_mla_metadata(
+                seq_lens.to(torch.int32),
+                Q_LEN * self.num_q_heads // self.num_kv_heads,
+                self.num_kv_heads,
+            )
+            self.cuda_graph_mla_metadata.copy_(mla_metadata)
+            self.cuda_graph_num_splits[: bs + 1].copy_(num_splits)
+            self.forward_metadata.mla_metadata = self.cuda_graph_mla_metadata
+            self.forward_metadata.num_splits = self.cuda_graph_num_splits[: bs + 1]
+            self.forward_metadata.block_kv_indices = self.cuda_graph_kv_indices[
+                :bs, :max_seqlen_pad
+            ]
+
+        else:
+            super().init_forward_metadata_replay_cuda_graph(
+                bs,
+                req_pool_indices,
+                seq_lens,
+                seq_lens_sum,
+                encoder_lens,
+                forward_mode,
+                spec_info,
+                seq_lens_cpu,
+            )
+
+    def forward_decode(
+        self,
+        q: torch.Tensor,
+        k: torch.Tensor,
+        v: torch.Tensor,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
+    ):
+        cache_loc = forward_batch.out_cache_loc
+
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(
+                    layer,
+                    cache_loc,
+                    k,
+                    v,
+                )
+        bs = forward_batch.batch_size
+        k_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+
+        reshape_q = q.view(bs, -1, layer.tp_q_head_num, layer.head_dim)
+
+        o, _ = flash_mla_with_kvcache(
+            q=reshape_q,
+            k_cache=k_cache.view(-1, PAGE_SIZE, 1, self.kv_cache_dim),
+            block_table=self.forward_metadata.block_kv_indices,
+            cache_seqlens=forward_batch.seq_lens.to(torch.int32),
+            head_dim_v=self.kv_lora_rank,  # TODO Retrieve from config.
+            tile_scheduler_metadata=self.forward_metadata.flashmla_metadata,
+            num_splits=self.forward_metadata.num_splits,
+            softmax_scale=layer.scaling,
+            causal=False,
+        )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
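
Note: FlashMLA consumes the KV cache in fixed 64-token pages, so the backend above builds a (batch_size, ceil(max_len / PAGE_SIZE)) table of page indices padded with -1 and fills it with create_flashmla_kv_indices_triton. The pure-PyTorch toy below is not part of the diff; it only mimics the shape and padding of that table, and the page ids are a dummy arange rather than real cache pages derived from req_to_token.

import math
import torch

PAGE_SIZE = 64  # FlashMLA only supports page size 64 (see the constant above)
seq_lens = torch.tensor([100, 200])  # toy decode lengths for two requests
max_seqlen_pad = math.ceil(seq_lens.max().item() / PAGE_SIZE)  # 4 pages

# Same shape and -1 fill as block_kv_indices in init_forward_metadata.
block_kv_indices = torch.full((len(seq_lens), max_seqlen_pad), -1, dtype=torch.int32)
for i, sl in enumerate(seq_lens.tolist()):
    num_pages = math.ceil(sl / PAGE_SIZE)
    # Dummy page ids; the Triton kernel derives the real ones from req_to_token.
    block_kv_indices[i, :num_pages] = torch.arange(num_pages, dtype=torch.int32)
# block_kv_indices is then handed to flash_mla_with_kvcache as block_table.
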