sglang 0.4.2.post4__py3-none-any.whl → 0.4.3.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/global_config.py +2 -0
  2. sglang/lang/backend/openai.py +5 -0
  3. sglang/lang/chat_template.py +22 -7
  4. sglang/lang/ir.py +1 -0
  5. sglang/srt/configs/__init__.py +6 -3
  6. sglang/srt/configs/model_config.py +2 -0
  7. sglang/srt/configs/qwen2_5_vl_config.py +1003 -0
  8. sglang/srt/entrypoints/engine.py +18 -3
  9. sglang/srt/hf_transformers_utils.py +2 -3
  10. sglang/srt/layers/attention/flashinfer_backend.py +235 -110
  11. sglang/srt/layers/attention/triton_backend.py +358 -72
  12. sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
  13. sglang/srt/layers/linear.py +12 -5
  14. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  15. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +2 -2
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +200 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +200 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +200 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +178 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +200 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +175 -0
  23. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -2
  24. sglang/srt/layers/moe/fused_moe_triton/layer.py +2 -0
  25. sglang/srt/layers/moe/topk.py +1 -1
  26. sglang/srt/layers/quantization/__init__.py +51 -5
  27. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  29. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +29 -29
  31. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  32. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +33 -33
  33. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  34. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  35. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  36. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +27 -27
  37. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  38. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +31 -31
  39. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  40. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +24 -24
  41. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  42. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +30 -30
  43. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  44. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +42 -42
  45. sglang/srt/layers/quantization/fp8_kernel.py +123 -17
  46. sglang/srt/layers/quantization/fp8_utils.py +33 -4
  47. sglang/srt/managers/detokenizer_manager.py +1 -0
  48. sglang/srt/managers/image_processor.py +217 -122
  49. sglang/srt/managers/io_struct.py +4 -0
  50. sglang/srt/managers/schedule_batch.py +16 -3
  51. sglang/srt/managers/scheduler.py +29 -0
  52. sglang/srt/managers/tokenizer_manager.py +6 -0
  53. sglang/srt/managers/tp_worker_overlap_thread.py +4 -0
  54. sglang/srt/model_executor/cuda_graph_runner.py +12 -1
  55. sglang/srt/model_executor/forward_batch_info.py +4 -1
  56. sglang/srt/model_executor/model_runner.py +12 -2
  57. sglang/srt/models/deepseek_nextn.py +295 -0
  58. sglang/srt/models/deepseek_v2.py +21 -8
  59. sglang/srt/models/llava.py +2 -1
  60. sglang/srt/models/qwen2_5_vl.py +722 -0
  61. sglang/srt/models/qwen2_vl.py +2 -1
  62. sglang/srt/openai_api/adapter.py +17 -3
  63. sglang/srt/server_args.py +26 -4
  64. sglang/srt/speculative/eagle_worker.py +35 -10
  65. sglang/srt/speculative/spec_info.py +11 -1
  66. sglang/srt/utils.py +7 -0
  67. sglang/utils.py +99 -19
  68. sglang/version.py +1 -1
  69. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/METADATA +5 -4
  70. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/RECORD +73 -55
  71. sglang/srt/configs/qwen2vl.py +0 -130
  72. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.2.post4.dist-info → sglang-0.4.3.post1.dist-info}/top_level.txt +0 -0
@@ -115,6 +115,9 @@ class Engine:
115
115
  sampling_params: Optional[Union[List[Dict], Dict]] = None,
116
116
  # The token ids for text; one can either specify text or input_ids.
117
117
  input_ids: Optional[Union[List[List[int]], List[int]]] = None,
118
+ # The image input. It can be a file name, a url, or base64 encoded string.
119
+ # See also python/sglang/srt/utils.py:load_image.
120
+ image_data: Optional[Union[List[str], str]] = None,
118
121
  return_logprob: Optional[Union[List[bool], bool]] = False,
119
122
  logprob_start_len: Optional[Union[List[int], int]] = None,
120
123
  top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -126,14 +129,20 @@ class Engine:
126
129
  The arguments of this function is the same as `sglang/srt/managers/io_struct.py::GenerateReqInput`.
127
130
  Please refer to `GenerateReqInput` for the documentation.
128
131
  """
132
+ modalities_list = []
133
+ if image_data is not None:
134
+ modalities_list.append("image")
135
+
129
136
  obj = GenerateReqInput(
130
137
  text=prompt,
131
138
  input_ids=input_ids,
132
139
  sampling_params=sampling_params,
140
+ image_data=image_data,
133
141
  return_logprob=return_logprob,
134
142
  logprob_start_len=logprob_start_len,
135
143
  top_logprobs_num=top_logprobs_num,
136
144
  lora_path=lora_path,
145
+ modalities=modalities_list,
137
146
  custom_logit_processor=custom_logit_processor,
138
147
  stream=stream,
139
148
  )
@@ -162,6 +171,9 @@ class Engine:
162
171
  sampling_params: Optional[Union[List[Dict], Dict]] = None,
163
172
  # The token ids for text; one can either specify text or input_ids.
164
173
  input_ids: Optional[Union[List[List[int]], List[int]]] = None,
174
+ # The image input. It can be a file name, a url, or base64 encoded string.
175
+ # See also python/sglang/srt/utils.py:load_image.
176
+ image_data: Optional[Union[List[str], str]] = None,
165
177
  return_logprob: Optional[Union[List[bool], bool]] = False,
166
178
  logprob_start_len: Optional[Union[List[int], int]] = None,
167
179
  top_logprobs_num: Optional[Union[List[int], int]] = None,
@@ -177,6 +189,7 @@ class Engine:
177
189
  text=prompt,
178
190
  input_ids=input_ids,
179
191
  sampling_params=sampling_params,
192
+ image_data=image_data,
180
193
  return_logprob=return_logprob,
181
194
  logprob_start_len=logprob_start_len,
182
195
  top_logprobs_num=top_logprobs_num,
@@ -297,7 +310,7 @@ def _set_envs_and_config(server_args: ServerArgs):
297
310
  # Set global environments
298
311
  os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
299
312
  os.environ["NCCL_CUMEM_ENABLE"] = "0"
300
- os.environ["NCCL_NVLS_ENABLE"] = "0"
313
+ os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
301
314
  os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
302
315
  os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
303
316
 
@@ -317,7 +330,7 @@ def _set_envs_and_config(server_args: ServerArgs):
317
330
  if server_args.attention_backend == "flashinfer":
318
331
  assert_pkg_version(
319
332
  "flashinfer_python",
320
- "0.2.0.post2",
333
+ "0.2.1.post1",
321
334
  "Please uninstall the old version and "
322
335
  "reinstall the latest version by following the instructions "
323
336
  "at https://docs.flashinfer.ai/installation.html.",
@@ -425,7 +438,9 @@ def _launch_subprocesses(server_args: ServerArgs) -> Tuple[TokenizerManager, Dic
425
438
  # Launch tokenizer process
426
439
  tokenizer_manager = TokenizerManager(server_args, port_args)
427
440
  if server_args.chat_template:
428
- load_chat_template_for_openai_api(tokenizer_manager, server_args.chat_template)
441
+ load_chat_template_for_openai_api(
442
+ tokenizer_manager, server_args.chat_template, server_args.model_path
443
+ )
429
444
 
430
445
  # Wait for the model to finish loading
431
446
  scheduler_infos = []
@@ -30,16 +30,15 @@ from transformers import (
30
30
  )
31
31
  from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
32
32
 
33
- from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2VLConfig
33
+ from sglang.srt.configs import ChatGLMConfig, DbrxConfig, ExaoneConfig, Qwen2_5_VLConfig
34
34
 
35
35
  _CONFIG_REGISTRY: Dict[str, Type[PretrainedConfig]] = {
36
36
  ChatGLMConfig.model_type: ChatGLMConfig,
37
37
  DbrxConfig.model_type: DbrxConfig,
38
38
  ExaoneConfig.model_type: ExaoneConfig,
39
- Qwen2VLConfig.model_type: Qwen2VLConfig,
39
+ Qwen2_5_VLConfig.model_type: Qwen2_5_VLConfig,
40
40
  }
41
41
 
42
-
43
42
  for name, cls in _CONFIG_REGISTRY.items():
44
43
  with contextlib.suppress(ValueError):
45
44
  AutoConfig.register(name, cls)
@@ -7,6 +7,7 @@ FlashInfer is faster and Triton is easier to customize.
7
7
  Each backend supports two operators: extend (i.e. prefill with cached prefix) and decode.
8
8
  """
9
9
 
10
+ import math
10
11
  import os
11
12
  from dataclasses import dataclass
12
13
  from enum import Enum, auto
@@ -20,6 +21,7 @@ import triton.language as tl
20
21
  from sglang.global_config import global_config
21
22
  from sglang.srt.layers.attention import AttentionBackend
22
23
  from sglang.srt.layers.dp_attention import get_attention_tp_size
24
+ from sglang.srt.managers.schedule_batch import global_server_args_dict
23
25
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
24
26
  from sglang.srt.utils import is_flashinfer_available
25
27
 
@@ -35,7 +37,7 @@ if is_flashinfer_available():
35
37
  BatchPrefillWithRaggedKVCacheWrapper,
36
38
  )
37
39
  from flashinfer.cascade import merge_state
38
- from flashinfer.decode import PosEncodingMode
40
+ from flashinfer.mla import BatchMLAPagedAttentionWrapper
39
41
 
40
42
 
41
43
  class WrapperDispatch(Enum):
@@ -45,7 +47,9 @@ class WrapperDispatch(Enum):
45
47
 
46
48
  @dataclass
47
49
  class DecodeMetadata:
48
- decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper]
50
+ decode_wrappers: List[
51
+ Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
52
+ ]
49
53
 
50
54
 
51
55
  @dataclass
@@ -103,6 +107,12 @@ class FlashInferAttnBackend(AttentionBackend):
103
107
  if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
104
108
  global_config.flashinfer_workspace_size = 512 * 1024 * 1024
105
109
 
110
+ self.enable_flashinfer_mla = False
111
+ if "DeepseekV3ForCausalLM" in model_runner.model_config.hf_config.architectures:
112
+ if global_server_args_dict["enable_flashinfer_mla"]:
113
+ self.enable_flashinfer_mla = True
114
+ global_config.enable_flashinfer_mla = True
115
+
106
116
  # Allocate buffers
107
117
  global global_workspace_buffer
108
118
  if global_workspace_buffer is None:
@@ -120,6 +130,13 @@ class FlashInferAttnBackend(AttentionBackend):
120
130
  )
121
131
  for _ in range(self.num_wrappers)
122
132
  ]
133
+ if self.enable_flashinfer_mla:
134
+ self.qo_indptr = [
135
+ torch.zeros(
136
+ (max_bs + 1,), dtype=torch.int32, device=model_runner.device
137
+ )
138
+ for _ in range(self.num_wrappers)
139
+ ]
123
140
  else:
124
141
  assert self.num_wrappers == 1
125
142
  self.kv_indptr = [kv_indptr_buf]
@@ -153,13 +170,18 @@ class FlashInferAttnBackend(AttentionBackend):
153
170
  self.prefill_wrappers_verify.append(
154
171
  BatchPrefillWithPagedKVCacheWrapper(self.workspace_buffer, "NHD")
155
172
  )
156
- self.decode_wrappers.append(
157
- BatchDecodeWithPagedKVCacheWrapper(
158
- self.workspace_buffer,
159
- "NHD",
160
- use_tensor_cores=self.decode_use_tensor_cores,
173
+ if self.enable_flashinfer_mla:
174
+ self.decode_wrappers.append(
175
+ BatchMLAPagedAttentionWrapper(self.workspace_buffer, backend="fa2")
176
+ )
177
+ else:
178
+ self.decode_wrappers.append(
179
+ BatchDecodeWithPagedKVCacheWrapper(
180
+ self.workspace_buffer,
181
+ "NHD",
182
+ use_tensor_cores=self.decode_use_tensor_cores,
183
+ )
161
184
  )
162
- )
163
185
 
164
186
  # Create indices updater
165
187
  if not skip_prefill:
@@ -274,19 +296,32 @@ class FlashInferAttnBackend(AttentionBackend):
274
296
  if forward_mode.is_decode_or_idle():
275
297
  decode_wrappers = []
276
298
  for i in range(self.num_wrappers):
277
- decode_wrappers.append(
278
- BatchDecodeWithPagedKVCacheWrapper(
279
- self.workspace_buffer,
280
- "NHD",
281
- use_cuda_graph=True,
282
- use_tensor_cores=self.decode_use_tensor_cores,
283
- paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
284
- paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
285
- paged_kv_last_page_len_buffer=self.kv_last_page_len[
286
- :num_tokens
287
- ],
299
+ if self.enable_flashinfer_mla:
300
+ decode_wrappers.append(
301
+ BatchMLAPagedAttentionWrapper(
302
+ self.workspace_buffer,
303
+ use_cuda_graph=True,
304
+ qo_indptr=self.qo_indptr[i][: num_tokens + 1],
305
+ kv_indptr=self.kv_indptr[i][: num_tokens + 1],
306
+ kv_indices=self.cuda_graph_kv_indices[i],
307
+ kv_len_arr=self.kv_last_page_len[:num_tokens],
308
+ backend="fa2",
309
+ )
310
+ )
311
+ else:
312
+ decode_wrappers.append(
313
+ BatchDecodeWithPagedKVCacheWrapper(
314
+ self.workspace_buffer,
315
+ "NHD",
316
+ use_cuda_graph=True,
317
+ use_tensor_cores=self.decode_use_tensor_cores,
318
+ paged_kv_indptr_buffer=self.kv_indptr[i][: num_tokens + 1],
319
+ paged_kv_indices_buffer=self.cuda_graph_kv_indices[i],
320
+ paged_kv_last_page_len_buffer=self.kv_last_page_len[
321
+ :num_tokens
322
+ ],
323
+ )
288
324
  )
289
- )
290
325
  seq_lens_sum = seq_lens.sum().item()
291
326
  self.indices_updater_decode.update(
292
327
  req_pool_indices,
@@ -375,64 +410,94 @@ class FlashInferAttnBackend(AttentionBackend):
375
410
  forward_batch: ForwardBatch,
376
411
  save_kv_cache=True,
377
412
  ):
378
- prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
379
- self._get_wrapper_idx(layer)
380
- ]
381
- cache_loc = (
382
- forward_batch.out_cache_loc
383
- if not layer.is_cross_attention
384
- else forward_batch.encoder_out_cache_loc
385
- )
413
+ if global_config.enable_flashinfer_mla:
414
+ cache_loc = (
415
+ forward_batch.out_cache_loc
416
+ if not layer.is_cross_attention
417
+ else forward_batch.encoder_out_cache_loc
418
+ )
386
419
 
387
- logits_soft_cap = layer.logit_cap
420
+ logits_soft_cap = layer.logit_cap
388
421
 
389
- if not self.forward_metadata.use_ragged:
390
- if k is not None:
391
- assert v is not None
392
- if save_kv_cache:
393
- forward_batch.token_to_kv_pool.set_kv_buffer(
394
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
395
- )
396
-
397
- o = prefill_wrapper_paged.forward(
398
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
399
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
400
- causal=not layer.is_cross_attention,
401
- sm_scale=layer.scaling,
402
- window_left=layer.sliding_window_size,
403
- logits_soft_cap=logits_soft_cap,
404
- k_scale=layer.k_scale,
405
- v_scale=layer.v_scale,
406
- )
407
- else:
408
- o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
422
+ o1, _ = self.prefill_wrapper_ragged.forward_return_lse(
409
423
  q.view(-1, layer.tp_q_head_num, layer.head_dim),
410
424
  k.view(-1, layer.tp_k_head_num, layer.head_dim),
411
- v.view(-1, layer.tp_v_head_num, layer.head_dim),
425
+ v.view(-1, layer.tp_v_head_num, layer.v_head_dim),
412
426
  causal=True,
413
427
  sm_scale=layer.scaling,
414
428
  logits_soft_cap=logits_soft_cap,
415
429
  )
416
430
 
417
- if self.forward_metadata.extend_no_prefix:
418
- o = o1
419
- else:
420
- o2, s2 = prefill_wrapper_paged.forward_return_lse(
431
+ o = o1
432
+
433
+ if save_kv_cache:
434
+ forward_batch.token_to_kv_pool.set_kv_buffer(
435
+ layer,
436
+ cache_loc,
437
+ k,
438
+ v,
439
+ )
440
+
441
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
442
+ else:
443
+ prefill_wrapper_paged = self.forward_metadata.prefill_wrappers[
444
+ self._get_wrapper_idx(layer)
445
+ ]
446
+ cache_loc = (
447
+ forward_batch.out_cache_loc
448
+ if not layer.is_cross_attention
449
+ else forward_batch.encoder_out_cache_loc
450
+ )
451
+
452
+ logits_soft_cap = layer.logit_cap
453
+
454
+ if not self.forward_metadata.use_ragged:
455
+ if k is not None:
456
+ assert v is not None
457
+ if save_kv_cache:
458
+ forward_batch.token_to_kv_pool.set_kv_buffer(
459
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
460
+ )
461
+
462
+ o = prefill_wrapper_paged.forward(
421
463
  q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
422
464
  forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
423
- causal=False,
465
+ causal=not layer.is_cross_attention,
424
466
  sm_scale=layer.scaling,
425
- logits_soft_cap=layer.logit_cap,
467
+ window_left=layer.sliding_window_size,
468
+ logits_soft_cap=logits_soft_cap,
469
+ k_scale=layer.k_scale,
470
+ v_scale=layer.v_scale,
471
+ )
472
+ else:
473
+ o1, s1 = self.prefill_wrapper_ragged.forward_return_lse(
474
+ q.view(-1, layer.tp_q_head_num, layer.head_dim),
475
+ k.view(-1, layer.tp_k_head_num, layer.head_dim),
476
+ v.view(-1, layer.tp_v_head_num, layer.head_dim),
477
+ causal=True,
478
+ sm_scale=layer.scaling,
479
+ logits_soft_cap=logits_soft_cap,
426
480
  )
427
481
 
428
- o, _ = merge_state(o1, s1, o2, s2)
482
+ if self.forward_metadata.extend_no_prefix:
483
+ o = o1
484
+ else:
485
+ o2, s2 = prefill_wrapper_paged.forward_return_lse(
486
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
487
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
488
+ causal=False,
489
+ sm_scale=layer.scaling,
490
+ logits_soft_cap=layer.logit_cap,
491
+ )
429
492
 
430
- if save_kv_cache:
431
- forward_batch.token_to_kv_pool.set_kv_buffer(
432
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
433
- )
493
+ o, _ = merge_state(o1, s1, o2, s2)
494
+
495
+ if save_kv_cache:
496
+ forward_batch.token_to_kv_pool.set_kv_buffer(
497
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
498
+ )
434
499
 
435
- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
500
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)
436
501
 
437
502
  def forward_decode(
438
503
  self,
@@ -452,23 +517,45 @@ class FlashInferAttnBackend(AttentionBackend):
452
517
  else forward_batch.encoder_out_cache_loc
453
518
  )
454
519
 
455
- if k is not None:
456
- assert v is not None
457
- if save_kv_cache:
458
- forward_batch.token_to_kv_pool.set_kv_buffer(
459
- layer, cache_loc, k, v, layer.k_scale, layer.v_scale
460
- )
520
+ if self.enable_flashinfer_mla:
521
+ if k is not None:
522
+ assert v is not None
523
+ if save_kv_cache:
524
+ forward_batch.token_to_kv_pool.set_kv_buffer(
525
+ layer,
526
+ cache_loc,
527
+ k,
528
+ v,
529
+ )
530
+ reshaped_q = q.view(-1, layer.tp_q_head_num, layer.head_dim)
531
+ k_buffer = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
532
+ reshaped_k = k_buffer.view(-1, 1, layer.head_dim)
533
+ o = decode_wrapper.run(
534
+ reshaped_q[:, :, : layer.v_head_dim],
535
+ reshaped_q[:, :, layer.v_head_dim :],
536
+ reshaped_k[:, :, : layer.v_head_dim],
537
+ reshaped_k[:, :, layer.v_head_dim :],
538
+ )
461
539
 
462
- o = decode_wrapper.forward(
463
- q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
464
- forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
465
- sm_scale=layer.scaling,
466
- logits_soft_cap=layer.logit_cap,
467
- k_scale=layer.k_scale,
468
- v_scale=layer.v_scale,
469
- )
540
+ return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
541
+ else:
542
+ if k is not None:
543
+ assert v is not None
544
+ if save_kv_cache:
545
+ forward_batch.token_to_kv_pool.set_kv_buffer(
546
+ layer, cache_loc, k, v, layer.k_scale, layer.v_scale
547
+ )
470
548
 
471
- return o.view(-1, layer.tp_q_head_num * layer.head_dim)
549
+ o = decode_wrapper.forward(
550
+ q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
551
+ forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id),
552
+ sm_scale=layer.scaling,
553
+ logits_soft_cap=layer.logit_cap,
554
+ k_scale=layer.k_scale,
555
+ v_scale=layer.v_scale,
556
+ )
557
+
558
+ return o.view(-1, layer.tp_q_head_num * layer.head_dim)
472
559
 
473
560
  def _get_wrapper_idx(self, layer: RadixAttention):
474
561
  if self.num_wrappers == 1:
@@ -516,7 +603,9 @@ class FlashInferIndicesUpdaterDecode:
516
603
  req_pool_indices: torch.Tensor,
517
604
  seq_lens: torch.Tensor,
518
605
  seq_lens_sum: int,
519
- decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
606
+ decode_wrappers: List[
607
+ Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
608
+ ],
520
609
  encoder_lens: Optional[torch.Tensor],
521
610
  spec_info: Optional[SpecInfo],
522
611
  ):
@@ -528,7 +617,9 @@ class FlashInferIndicesUpdaterDecode:
528
617
  req_pool_indices: torch.Tensor,
529
618
  seq_lens: torch.Tensor,
530
619
  seq_lens_sum: int,
531
- decode_wrappers: List[BatchDecodeWithPagedKVCacheWrapper],
620
+ decode_wrappers: List[
621
+ Union[BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper]
622
+ ],
532
623
  encoder_lens: Optional[torch.Tensor],
533
624
  spec_info: Optional[SpecInfo],
534
625
  ):
@@ -609,7 +700,9 @@ class FlashInferIndicesUpdaterDecode:
609
700
 
610
701
  def call_begin_forward(
611
702
  self,
612
- wrapper: BatchDecodeWithPagedKVCacheWrapper,
703
+ wrapper: Union[
704
+ BatchDecodeWithPagedKVCacheWrapper, BatchMLAPagedAttentionWrapper
705
+ ],
613
706
  req_pool_indices: torch.Tensor,
614
707
  paged_kernel_lens: torch.Tensor,
615
708
  paged_kernel_lens_sum: int,
@@ -637,18 +730,37 @@ class FlashInferIndicesUpdaterDecode:
637
730
  kv_indptr, kv_indices = spec_info.kv_indptr, spec_info.kv_indices
638
731
  bs = kv_indptr.shape[0] - 1
639
732
 
640
- wrapper.begin_forward(
641
- kv_indptr,
642
- kv_indices,
643
- self.kv_last_page_len[:bs],
644
- self.num_qo_heads,
645
- self.num_kv_heads,
646
- self.head_dim,
647
- 1,
648
- data_type=self.data_type,
649
- q_data_type=self.q_data_type,
650
- non_blocking=True,
651
- )
733
+ if global_config.enable_flashinfer_mla:
734
+ sm_scale = 1.0 / math.sqrt(192)
735
+ q_indptr = torch.arange(0, bs + 1).to(0).int()
736
+ kv_lens = paged_kernel_lens.to(torch.int32)
737
+ wrapper.plan(
738
+ q_indptr,
739
+ kv_indptr,
740
+ kv_indices,
741
+ kv_lens,
742
+ self.num_qo_heads,
743
+ 512,
744
+ 64,
745
+ 1,
746
+ False,
747
+ sm_scale,
748
+ self.data_type,
749
+ self.data_type,
750
+ )
751
+ else:
752
+ wrapper.begin_forward(
753
+ kv_indptr,
754
+ kv_indices,
755
+ self.kv_last_page_len[:bs],
756
+ self.num_qo_heads,
757
+ self.num_kv_heads,
758
+ self.head_dim,
759
+ 1,
760
+ data_type=self.data_type,
761
+ q_data_type=self.q_data_type,
762
+ non_blocking=True,
763
+ )
652
764
 
653
765
 
654
766
  class FlashInferIndicesUpdaterPrefill:
@@ -857,30 +969,42 @@ class FlashInferIndicesUpdaterPrefill:
857
969
 
858
970
  # extend part
859
971
  if use_ragged:
860
- wrapper_ragged.begin_forward(
861
- qo_indptr,
972
+ if global_config.enable_flashinfer_mla:
973
+ wrapper_ragged.begin_forward(
974
+ qo_indptr=qo_indptr,
975
+ kv_indptr=qo_indptr,
976
+ num_qo_heads=self.num_qo_heads,
977
+ num_kv_heads=self.num_kv_heads,
978
+ head_dim_qk=192,
979
+ head_dim_vo=128,
980
+ q_data_type=self.q_data_type,
981
+ )
982
+ else:
983
+ wrapper_ragged.begin_forward(
984
+ qo_indptr,
985
+ qo_indptr,
986
+ self.num_qo_heads,
987
+ self.num_kv_heads,
988
+ self.head_dim,
989
+ q_data_type=self.q_data_type,
990
+ )
991
+
992
+ if not global_config.enable_flashinfer_mla:
993
+ # cached part
994
+ wrapper_paged.begin_forward(
862
995
  qo_indptr,
996
+ kv_indptr,
997
+ kv_indices,
998
+ self.kv_last_page_len[:bs],
863
999
  self.num_qo_heads,
864
1000
  self.num_kv_heads,
865
1001
  self.head_dim,
1002
+ 1,
866
1003
  q_data_type=self.q_data_type,
1004
+ custom_mask=custom_mask,
1005
+ non_blocking=True,
867
1006
  )
868
1007
 
869
- # cached part
870
- wrapper_paged.begin_forward(
871
- qo_indptr,
872
- kv_indptr,
873
- kv_indices,
874
- self.kv_last_page_len[:bs],
875
- self.num_qo_heads,
876
- self.num_kv_heads,
877
- self.head_dim,
878
- 1,
879
- q_data_type=self.q_data_type,
880
- custom_mask=custom_mask,
881
- non_blocking=True,
882
- )
883
-
884
1008
 
885
1009
  class FlashInferMultiStepDraftBackend:
886
1010
  """
@@ -947,7 +1071,7 @@ class FlashInferMultiStepDraftBackend:
947
1071
  triton.next_power_of_2(bs),
948
1072
  )
949
1073
 
950
- for i in range(self.speculative_num_steps):
1074
+ for i in range(self.speculative_num_steps - 1):
951
1075
  forward_batch.spec_info.kv_indptr = self.kv_indptr[i, : bs + 1]
952
1076
  forward_batch.spec_info.kv_indices = kv_indices_buffer[i][
953
1077
  : seq_lens_sum * self.topk + bs * (i + 1)
@@ -1163,6 +1287,7 @@ def fast_decode_plan(
1163
1287
  window_left,
1164
1288
  logits_soft_cap,
1165
1289
  head_dim,
1290
+ head_dim,
1166
1291
  empty_q_data,
1167
1292
  empty_kv_cache,
1168
1293
  stream.cuda_stream,