sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/bench_serving.py +23 -3
  2. sglang/srt/configs/deepseekvl2.py +10 -1
  3. sglang/srt/configs/model_config.py +5 -16
  4. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
  5. sglang/srt/distributed/parallel_state.py +32 -5
  6. sglang/srt/entrypoints/http_server.py +7 -1
  7. sglang/srt/entrypoints/verl_engine.py +2 -0
  8. sglang/srt/function_call_parser.py +0 -1
  9. sglang/srt/layers/attention/flashattention_backend.py +218 -79
  10. sglang/srt/layers/dp_attention.py +12 -1
  11. sglang/srt/layers/moe/topk.py +30 -3
  12. sglang/srt/layers/quantization/__init__.py +134 -165
  13. sglang/srt/layers/quantization/awq.py +200 -0
  14. sglang/srt/layers/quantization/fp8_kernel.py +2 -1
  15. sglang/srt/layers/quantization/gptq.py +30 -40
  16. sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
  17. sglang/srt/layers/rotary_embedding.py +12 -0
  18. sglang/srt/lora/backend/base_backend.py +4 -4
  19. sglang/srt/lora/backend/flashinfer_backend.py +12 -9
  20. sglang/srt/lora/backend/triton_backend.py +5 -8
  21. sglang/srt/lora/layers.py +19 -33
  22. sglang/srt/lora/lora_manager.py +20 -7
  23. sglang/srt/lora/mem_pool.py +12 -6
  24. sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
  25. sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
  26. sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
  27. sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
  28. sglang/srt/lora/utils.py +6 -0
  29. sglang/srt/managers/io_struct.py +4 -2
  30. sglang/srt/managers/multimodal_processors/clip.py +63 -0
  31. sglang/srt/managers/schedule_batch.py +1 -0
  32. sglang/srt/managers/scheduler.py +25 -19
  33. sglang/srt/managers/tokenizer_manager.py +0 -1
  34. sglang/srt/managers/tp_worker.py +3 -0
  35. sglang/srt/model_executor/cuda_graph_runner.py +9 -8
  36. sglang/srt/model_executor/model_runner.py +9 -6
  37. sglang/srt/model_loader/loader.py +11 -1
  38. sglang/srt/model_loader/weight_utils.py +6 -3
  39. sglang/srt/models/clip.py +563 -0
  40. sglang/srt/models/deepseek_janus_pro.py +2 -2
  41. sglang/srt/models/deepseek_v2.py +151 -26
  42. sglang/srt/models/gemma3_causal.py +12 -2
  43. sglang/srt/models/gemma3_mm.py +6 -0
  44. sglang/srt/openai_api/adapter.py +88 -87
  45. sglang/srt/openai_api/protocol.py +10 -5
  46. sglang/srt/patch_torch.py +71 -0
  47. sglang/srt/server_args.py +21 -11
  48. sglang/srt/speculative/eagle_worker.py +1 -1
  49. sglang/srt/utils.py +33 -0
  50. sglang/test/runners.py +27 -2
  51. sglang/test/test_utils.py +1 -1
  52. sglang/version.py +1 -1
  53. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
  54. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
  55. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
  56. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
  57. {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -965,7 +965,7 @@ async def benchmark(
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
-    lora_name: str,
+    lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
     pd_seperated: bool = False,
@@ -988,6 +988,11 @@ async def benchmark(
     # Warmup
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    if lora_names != None and len(lora_names) != 0:
+        lora_name = lora_names[0]
+    else:
+        lora_name = None
+
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -1028,6 +1033,12 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
+        if lora_names != None and len(lora_names) != 0:
+            idx = random.randint(0, len(lora_names) - 1)
+            lora_name = lora_names[idx]
+        else:
+            lora_name = None
+
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
             request_rate=args.request_rate,
             max_concurrency=args.max_concurrency,
             disable_tqdm=args.disable_tqdm,
-            lora_name=args.lora_name,
+            lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
             pd_seperated=args.pd_seperated,
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
         print(f"Fail to set RLIMIT_NOFILE: {e}")


+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, [])
+        for lora_name in values:
+            getattr(namespace, self.dest).append(lora_name)
+
+
 if __name__ == "__main__":
     parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-name",
         type=str,
+        nargs="*",
         default=None,
-        help="The name of LoRA adapter",
+        action=LoRAPathAction,
+        help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
     )
     parser.add_argument(
         "--prompt-suffix",
sglang/srt/configs/deepseekvl2.py CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple

 import torch
-import torchvision.transforms as T
 from PIL import Image, ImageOps
 from transformers import (
     AutoProcessor,
@@ -76,6 +75,16 @@ class ImageTransform(object):
         self.std = std
         self.normalize = normalize

+        # only load torchvision.transforms when needed
+        try:
+            import torchvision.transforms as T
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
+            ) from err
+
         transform_pipelines = [T.ToTensor()]

         if normalize:
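The torchvision dependency is now imported lazily inside `ImageTransform`, so importing the DeepSeek-VL2 config no longer requires torchvision to be installed. A generic sketch of the deferred-import pattern used here (the helper name is hypothetical):

```python
def make_to_tensor():
    # Import torchvision only when a transform is actually constructed, and turn
    # a missing dependency into an actionable error message.
    try:
        import torchvision.transforms as T
    except ImportError as err:
        raise ImportError(
            "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
        ) from err
    return T.ToTensor()
```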
sglang/srt/configs/model_config.py CHANGED
@@ -22,11 +22,7 @@ import torch
 from transformers import PretrainedConfig

 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import (
-    BASE_QUANTIZATION_METHODS,
-    QUANTIZATION_METHODS,
-    VLLM_AVAILABLE,
-)
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.utils import get_bool_env_var, is_hip

 logger = logging.getLogger(__name__)
@@ -239,12 +235,7 @@ class ModelConfig:

     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-        # Select supported quantization methods based on vllm availability
-        if VLLM_AVAILABLE:
-            supported_quantization = [*QUANTIZATION_METHODS]
-        else:
-            supported_quantization = [*BASE_QUANTIZATION_METHODS]
-
+        supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -282,11 +273,7 @@ class ModelConfig:
             quant_method = quant_cfg.get("quant_method", "").lower()

             # Detect which checkpoint is it
-            # Only iterate through currently available quantization methods
-            available_methods = (
-                QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
-            )
-            for _, method in available_methods.items():
+            for _, method in QUANTIZATION_METHODS.items():
                 quantization_override = method.override_quantization_method(
                     quant_cfg, self.quantization
                 )
@@ -467,6 +454,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "InternLM2ForRewardModel" in model_architectures
         or "Qwen2ForRewardModel" in model_architectures
         or "Qwen2ForSequenceClassification" in model_architectures
+        or "CLIPModel" in model_architectures
     ):
         return False
     else:
@@ -488,6 +476,7 @@ multimodal_model_archs = [
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
+    "CLIPModel",
 ]


sglang/srt/distributed/device_communicators/custom_all_reduce.py CHANGED
@@ -5,7 +5,7 @@ import logging
 import os
 from contextlib import contextmanager
 from functools import wraps
-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, TypeVar, Union

 import torch
 import torch.distributed as dist
sglang/srt/distributed/parallel_state.py CHANGED
@@ -264,10 +264,16 @@ class GroupCoordinator:
         self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            self.ca_comm = CustomAllreduce(
-                group=self.cpu_group,
-                device=self.device,
-            )
+            try:
+                self.ca_comm = CustomAllreduce(
+                    group=self.cpu_group,
+                    device=self.device,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Setup Custom allreduce failed with {e}. To silence this "
+                    "warning, specify --disable-custom-all-reduce explicitly."
+                )

         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -439,6 +445,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)

+    def reduce_scatter(
+        self,
+        output: torch.Tensor,
+        input_list: List[torch.Tensor],
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -456,11 +471,23 @@ class GroupCoordinator:
                 output, input, group_name=self.unique_name
             )

-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(
+        self,
+        input_: torch.Tensor,
+        dim: int = -1,
+        tensor_list: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+
+        if tensor_list is not None:
+            # TODO(ch-wan): support other backends
+            return torch.distributed.all_gather(
+                tensor_list, input_, group=self.device_group
+            )
+
         assert (
             -input_.dim() <= dim < input_.dim()
         ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
sglang/srt/entrypoints/http_server.py CHANGED
@@ -561,7 +561,13 @@ def available_models():
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+        model_cards.append(
+            ModelCard(
+                id=served_model_name,
+                root=served_model_name,
+                max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+            )
+        )
     return ModelList(data=model_cards)


sglang/srt/entrypoints/verl_engine.py CHANGED
@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor

 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj

@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()
sglang/srt/function_call_parser.py CHANGED
@@ -290,7 +290,6 @@ class BaseFormatDetector(ABC):
                 calls=[
                     ToolCallItem(
                         tool_index=self.current_tool_id,
-                        name="",
                         parameters=argument_diff,
                     )
                 ],
sglang/srt/layers/attention/flashattention_backend.py CHANGED
@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union

 import torch

+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode

 if TYPE_CHECKING:
@@ -29,11 +31,11 @@ class FlashAttentionMetadata:

     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0


 class FlashAttentionBackend(AttentionBackend):
@@ -57,13 +59,16 @@ class FlashAttentionBackend(AttentionBackend):
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])

     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()

-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -79,21 +84,33 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
             else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-                metadata.max_seq_len_q = seqlens_in_batch.max().item()
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata

     def forward_extend(
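When `page_size > 1`, `init_forward_metadata` turns the token-level `req_to_token` mapping into page indices by sampling every `page_size`-th column and dividing by `page_size`. A toy illustration of that transformation with made-up cache locations (assumes a paged allocator keeps each request's tokens page-aligned):

```python
import torch

page_size = 4
# Toy req_to_token row: token slot -> KV cache location for one request of 8 tokens.
req_to_token = torch.tensor([[20, 21, 22, 23, 36, 37, 38, 39]])

# [0, page_size, 2 * page_size, ...]: one representative slot per page.
strided_indices = torch.arange(0, req_to_token.shape[1], page_size)
page_table = req_to_token[:, strided_indices] // page_size
print(page_table)  # tensor([[5, 9]]) -> this request occupies pages 5 and 9
```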
@@ -105,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )

         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
                 )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Use precomputed metadata
         metadata = self.forward_metadata

-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -130,26 +154,72 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        page_table = metadata.page_table
+
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def forward_decode(
         self,
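In the MLA branch, a single compressed buffer is stored per token: the first `v_head_dim` entries are the latent `c_kv` and the remainder is the rotary part `k_rope`, and the query is split the same way into `q_nope`/`q_rope` before `flash_attn_with_kvcache` is called with `qv=q_nope`. A shape-only sketch of that slicing with toy dimensions (plain torch, no FlashAttention call; the 512/64 split is just an example):

```python
import torch

num_tokens, tp_q_head_num = 3, 16
v_head_dim, rope_dim = 512, 64      # example stand-ins for the latent and rope widths
head_dim = v_head_dim + rope_dim    # width of the stored key buffer

q = torch.randn(num_tokens, tp_q_head_num, head_dim)
kv_cache = torch.randn(num_tokens, 1, head_dim)  # one latent "head" per token

q_nope, q_rope = q[:, :, :v_head_dim], q[:, :, v_head_dim:]
c_kv, k_rope = kv_cache[:, :, :v_head_dim], kv_cache[:, :, v_head_dim:]

print(q_nope.shape, q_rope.shape)  # torch.Size([3, 16, 512]) torch.Size([3, 16, 64])
print(c_kv.shape, k_rope.shape)    # torch.Size([3, 1, 512]) torch.Size([3, 1, 64])
```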
@@ -162,26 +232,29 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> torch.Tensor:
         """Forward pass with FlashAttention using precomputed metadata."""
         # Save KV cache if needed
-        if k is not None and v is not None and save_kv_cache:
-            cache_loc = (
-                forward_batch.out_cache_loc
-                if not layer.is_cross_attention
-                else forward_batch.encoder_out_cache_loc
-            )
-            forward_batch.token_to_kv_pool.set_kv_buffer(
-                layer, cache_loc, k, v, layer.k_scale, layer.v_scale
-            )
-
-        # Get KV cache
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )

         # Use precomputed metadata
         metadata = self.forward_metadata

-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -190,25 +263,79 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        # Run attention with precomputed values
-        o = flash_attn_with_kvcache(
-            q=q_reshaped,
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=1,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )

-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)

     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
@@ -223,7 +350,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }

@@ -274,21 +407,27 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-        metadata.cache_seqlens_int32 = seqlens_in_batch.to(torch.int32)
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(seqlens_in_batch, dim=0, dtype=torch.int32), (1, 0)
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-        # Precompute maximum sequence length
-        metadata.max_seq_len_k = seqlens_in_batch.max().item()
-        # Only zero out the part out of max_len_k
-        metadata.page_table[:, metadata.max_seq_len_k :].fill_(0)
-        # Then do the copy
-        metadata.page_table[:, : metadata.max_seq_len_k].copy_(
-            self.req_to_token[req_pool_indices[:bs], : metadata.max_seq_len_k]
-        )
-        self.forward_decode_metadata = metadata
+
+        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
+        page_indices = self.req_to_token[
+            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+        ]
+        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        metadata.page_table[:, max_seq_pages:].fill_(0)
+        self.forward_metadata = metadata

     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""