sglang 0.4.4.post2__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +23 -3
- sglang/srt/configs/deepseekvl2.py +10 -1
- sglang/srt/configs/model_config.py +5 -16
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +32 -5
- sglang/srt/entrypoints/http_server.py +7 -1
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/attention/flashattention_backend.py +218 -79
- sglang/srt/layers/dp_attention.py +12 -1
- sglang/srt/layers/moe/topk.py +30 -3
- sglang/srt/layers/quantization/__init__.py +134 -165
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/fp8_kernel.py +2 -1
- sglang/srt/layers/quantization/gptq.py +30 -40
- sglang/srt/layers/quantization/w8a8_fp8.py +1 -1
- sglang/srt/layers/rotary_embedding.py +12 -0
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +19 -33
- sglang/srt/lora/lora_manager.py +20 -7
- sglang/srt/lora/mem_pool.py +12 -6
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +6 -0
- sglang/srt/managers/io_struct.py +4 -2
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/schedule_batch.py +1 -0
- sglang/srt/managers/scheduler.py +25 -19
- sglang/srt/managers/tokenizer_manager.py +0 -1
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -8
- sglang/srt/model_executor/model_runner.py +9 -6
- sglang/srt/model_loader/loader.py +11 -1
- sglang/srt/model_loader/weight_utils.py +6 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +151 -26
- sglang/srt/models/gemma3_causal.py +12 -2
- sglang/srt/models/gemma3_mm.py +6 -0
- sglang/srt/openai_api/adapter.py +88 -87
- sglang/srt/openai_api/protocol.py +10 -5
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/server_args.py +21 -11
- sglang/srt/speculative/eagle_worker.py +1 -1
- sglang/srt/utils.py +33 -0
- sglang/test/runners.py +27 -2
- sglang/test/test_utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +8 -4
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +57 -53
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.4.post2.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -965,7 +965,7 @@ async def benchmark(
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
-    lora_name: str,
+    lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
     pd_seperated: bool = False,
@@ -988,6 +988,11 @@ async def benchmark(
     # Warmup
     print("Starting initial single prompt test run...")
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
+    if lora_names != None and len(lora_names) != 0:
+        lora_name = lora_names[0]
+    else:
+        lora_name = None
+
     test_input = RequestFuncInput(
         model=model_id,
         prompt=test_prompt,
@@ -1028,6 +1033,12 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
+        if lora_names != None and len(lora_names) != 0:
+            idx = random.randint(0, len(lora_names) - 1)
+            lora_name = lora_names[idx]
+        else:
+            lora_name = None
+
         request_func_input = RequestFuncInput(
             model=model_id,
             prompt=prompt,
@@ -1347,7 +1358,7 @@ def run_benchmark(args_: argparse.Namespace):
             request_rate=args.request_rate,
             max_concurrency=args.max_concurrency,
             disable_tqdm=args.disable_tqdm,
-            lora_name=args.lora_name,
+            lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
             pd_seperated=args.pd_seperated,
@@ -1366,6 +1377,13 @@ def set_ulimit(target_soft_limit=65535):
         print(f"Fail to set RLIMIT_NOFILE: {e}")
 
 
+class LoRAPathAction(argparse.Action):
+    def __call__(self, parser, namespace, values, option_string=None):
+        setattr(namespace, self.dest, [])
+        for lora_name in values:
+            getattr(namespace, self.dest).append(lora_name)
+
+
 if __name__ == "__main__":
     parser = ArgumentParser(description="Benchmark the online serving throughput.")
     parser.add_argument(
@@ -1509,8 +1527,10 @@ if __name__ == "__main__":
     parser.add_argument(
         "--lora-name",
         type=str,
+        nargs="*",
         default=None,
-
+        action=LoRAPathAction,
+        help="The names of LoRA adapters. You can provide a list of names in the format {name} {name} {name}...",
     )
     parser.add_argument(
         "--prompt-suffix",
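
Note (not part of the package): a minimal stand-alone sketch of how the new --lora-name option behaves with nargs="*" and the LoRAPathAction shown above. The adapter names here are hypothetical.

import argparse

class LoRAPathAction(argparse.Action):
    def __call__(self, parser, namespace, values, option_string=None):
        # Collect every value given after the flag into a list on the namespace.
        setattr(namespace, self.dest, [])
        for lora_name in values:
            getattr(namespace, self.dest).append(lora_name)

parser = argparse.ArgumentParser()
parser.add_argument("--lora-name", type=str, nargs="*", default=None, action=LoRAPathAction)

print(parser.parse_args(["--lora-name", "adapter_a", "adapter_b"]).lora_name)  # ['adapter_a', 'adapter_b']
print(parser.parse_args([]).lora_name)  # None, i.e. the benchmark runs without LoRA

During the benchmark run, one adapter name from this list is picked at random per request, which is how the new code exercises multiple LoRA adapters in a single run.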

sglang/srt/configs/deepseekvl2.py
CHANGED
@@ -4,7 +4,6 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Tuple
 
 import torch
-import torchvision.transforms as T
 from PIL import Image, ImageOps
 from transformers import (
     AutoProcessor,
@@ -76,6 +75,16 @@ class ImageTransform(object):
         self.std = std
         self.normalize = normalize
 
+        # only load torchvision.transforms when needed
+        try:
+            import torchvision.transforms as T
+
+            # FIXME: add version check for gguf
+        except ImportError as err:
+            raise ImportError(
+                "Please install torchvision via `pip install torchvision` to use Deepseek-VL2."
+            ) from err
+
         transform_pipelines = [T.ToTensor()]
 
         if normalize:

sglang/srt/configs/model_config.py
CHANGED
@@ -22,11 +22,7 @@ import torch
 from transformers import PretrainedConfig
 
 from sglang.srt.hf_transformers_utils import get_config, get_context_length
-from sglang.srt.layers.quantization import (
-    BASE_QUANTIZATION_METHODS,
-    QUANTIZATION_METHODS,
-    VLLM_AVAILABLE,
-)
+from sglang.srt.layers.quantization import QUANTIZATION_METHODS
 from sglang.srt.utils import get_bool_env_var, is_hip
 
 logger = logging.getLogger(__name__)
@@ -239,12 +235,7 @@ class ModelConfig:
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
     def _verify_quantization(self) -> None:
-
-        if VLLM_AVAILABLE:
-            supported_quantization = [*QUANTIZATION_METHODS]
-        else:
-            supported_quantization = [*BASE_QUANTIZATION_METHODS]
-
+        supported_quantization = [*QUANTIZATION_METHODS]
         rocm_supported_quantization = [
             "awq",
             "gptq",
@@ -282,11 +273,7 @@ class ModelConfig:
         quant_method = quant_cfg.get("quant_method", "").lower()
 
         # Detect which checkpoint is it
-
-        available_methods = (
-            QUANTIZATION_METHODS if VLLM_AVAILABLE else BASE_QUANTIZATION_METHODS
-        )
-        for _, method in available_methods.items():
+        for _, method in QUANTIZATION_METHODS.items():
             quantization_override = method.override_quantization_method(
                 quant_cfg, self.quantization
             )
@@ -467,6 +454,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "InternLM2ForRewardModel" in model_architectures
         or "Qwen2ForRewardModel" in model_architectures
        or "Qwen2ForSequenceClassification" in model_architectures
+        or "CLIPModel" in model_architectures
     ):
         return False
     else:
@@ -488,6 +476,7 @@ multimodal_model_archs = [
     "MllamaForConditionalGeneration",
     "Qwen2VLForConditionalGeneration",
     "Qwen2_5_VLForConditionalGeneration",
+    "CLIPModel",
 ]
 
 

sglang/srt/distributed/device_communicators/custom_all_reduce.py
CHANGED
@@ -5,7 +5,7 @@ import logging
 import os
 from contextlib import contextmanager
 from functools import wraps
-from typing import Callable, List, Optional, TypeVar, Union
+from typing import Any, Callable, List, Optional, TypeVar, Union
 
 import torch
 import torch.distributed as dist

sglang/srt/distributed/parallel_state.py
CHANGED
@@ -264,10 +264,16 @@ class GroupCoordinator:
         self.ca_comm: Optional[CustomAllreduce] = None
         if use_custom_allreduce and self.world_size > 1:
             # Initialize a custom fast all-reduce implementation.
-            self.ca_comm = CustomAllreduce(
-                group=self.cpu_group,
-                device=self.device,
-            )
+            try:
+                self.ca_comm = CustomAllreduce(
+                    group=self.cpu_group,
+                    device=self.device,
+                )
+            except Exception as e:
+                logger.warning(
+                    f"Setup Custom allreduce failed with {e}. To silence this "
+                    "warning, specify --disable-custom-all-reduce explicitly."
+                )
 
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
@@ -439,6 +445,15 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)
 
+    def reduce_scatter(
+        self,
+        output: torch.Tensor,
+        input_list: List[torch.Tensor],
+    ) -> None:
+        # TODO(ch-wan): support other backends
+        torch.distributed.reduce_scatter(output, input_list, group=self.device_group)
+        return output
+
     def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
         pynccl_comm = self.pynccl_comm
         if pynccl_comm is not None and not pynccl_comm.disabled:
@@ -456,11 +471,23 @@ class GroupCoordinator:
                 output, input, group_name=self.unique_name
             )
 
-    def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
+    def all_gather(
+        self,
+        input_: torch.Tensor,
+        dim: int = -1,
+        tensor_list: List[torch.Tensor] = None,
+    ) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
         if world_size == 1:
             return input_
+
+        if tensor_list is not None:
+            # TODO(ch-wan): support other backends
+            return torch.distributed.all_gather(
+                tensor_list, input_, group=self.device_group
+            )
+
         assert (
             -input_.dim() <= dim < input_.dim()
         ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
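
Note (illustration only): the new tensor_list path in GroupCoordinator.all_gather and the new reduce_scatter helper forward directly to torch.distributed. A minimal single-rank sketch of the gather-by-list call they rely on, using a local gloo group rather than sglang's device group; address and port are arbitrary:

import torch
import torch.distributed as dist

# Single-rank process group, just to demonstrate the call shape.
dist.init_process_group(
    backend="gloo", init_method="tcp://127.0.0.1:29511", rank=0, world_size=1
)

x = torch.arange(4, dtype=torch.float32)
gathered = [torch.empty_like(x) for _ in range(dist.get_world_size())]
dist.all_gather(gathered, x)  # fills one output tensor per rank
print(gathered)               # [tensor([0., 1., 2., 3.])]

dist.destroy_process_group()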

sglang/srt/entrypoints/http_server.py
CHANGED
@@ -561,7 +561,13 @@ def available_models():
     served_model_names = [_global_state.tokenizer_manager.served_model_name]
     model_cards = []
     for served_model_name in served_model_names:
-        model_cards.append(ModelCard(id=served_model_name, root=served_model_name))
+        model_cards.append(
+            ModelCard(
+                id=served_model_name,
+                root=served_model_name,
+                max_model_len=_global_state.tokenizer_manager.model_config.context_len,
+            )
+        )
     return ModelList(data=model_cards)
 
 

sglang/srt/entrypoints/verl_engine.py
CHANGED
@@ -19,6 +19,7 @@ import torch.distributed as dist
 from torch.distributed.tensor import DeviceMesh, DTensor
 
 from sglang.srt.model_executor.model_runner import LocalSerializedTensor
+from sglang.srt.patch_torch import monkey_patch_torch_reductions
 from sglang.srt.server import Engine
 from sglang.srt.utils import MultiprocessingSerializer, broadcast_pyobj
 
@@ -30,6 +31,7 @@ class VerlEngine:
         nnodes: int = 1,
         **kwargs,
     ):
+        monkey_patch_torch_reductions()
         self._device_mesh_cpu = device_mesh_cpu
         self._tp_rank = device_mesh_cpu.get_local_rank()
         self._tp_size = device_mesh_cpu.size()

sglang/srt/layers/attention/flashattention_backend.py
CHANGED
@@ -13,7 +13,9 @@ from typing import TYPE_CHECKING, Optional, Union
 
 import torch
 
+from sglang.srt.configs.model_config import AttentionArch
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 
 if TYPE_CHECKING:
@@ -29,11 +31,11 @@ class FlashAttentionMetadata:
 
     cu_seqlens_q: torch.Tensor = None
     cu_seqlens_k: torch.Tensor = None
+    max_seq_len_q: int = 0
     max_seq_len_k: int = 0
     window_size: tuple = (-1, -1)
     page_table: torch.Tensor = None
     cache_seqlens_int32: torch.Tensor = None
-    max_seq_len_q: int = 0
 
 
 class FlashAttentionBackend(AttentionBackend):
@@ -57,13 +59,16 @@ class FlashAttentionBackend(AttentionBackend):
         self.device = model_runner.device
         self.decode_cuda_graph_metadata = {}
         self.req_to_token = model_runner.req_to_token_pool.req_to_token
+        self.page_size = model_runner.page_size
+        self.use_mla = (
+            model_runner.model_config.attention_arch == AttentionArch.MLA
+        ) and (not global_server_args_dict["disable_mla"])
 
     def init_forward_metadata(self, forward_batch: ForwardBatch):
         """Initialize forward metadata to cache repetitive calculations."""
         # Create metadata based on forward mode
         metadata = FlashAttentionMetadata()
 
-        extend_seq_lens = forward_batch.extend_seq_lens
         # Get sequence information
         seqlens_in_batch = forward_batch.seq_lens
         # Precompute int32 version of sequence lengths
@@ -79,21 +84,33 @@ class FlashAttentionBackend(AttentionBackend):
         metadata.page_table = forward_batch.req_to_token_pool.req_to_token[
             forward_batch.req_pool_indices, : metadata.max_seq_len_k
         ]
+
+        # Precompute strided indices
+        # [0, page_size, 2 * page_size, ...]
+        if self.page_size > 1:
+            self.strided_indices = torch.arange(
+                0, metadata.page_table.shape[1], self.page_size, device=self.device
+            )
+            metadata.page_table = (
+                metadata.page_table[:, self.strided_indices] // self.page_size
+            )
+
         if forward_batch.forward_mode == ForwardMode.DECODE:
             # Precompute cumulative sequence lengths
             metadata.cu_seqlens_q = torch.arange(
                 0, batch_size + 1, dtype=torch.int32, device=device
             )
         else:
-            extend_no_prefix = not any(forward_batch.extend_prefix_lens)
             # Precompute cumulative sequence lengths
-            if not extend_no_prefix:
+            if any(forward_batch.extend_prefix_lens_cpu):
+                extend_seq_lens = forward_batch.extend_seq_lens
                 metadata.cu_seqlens_q = torch.nn.functional.pad(
                     torch.cumsum(extend_seq_lens, dim=0, dtype=torch.int32), (1, 0)
                 )
+                metadata.max_seq_len_q = max(forward_batch.extend_seq_lens_cpu)
            else:
                 metadata.cu_seqlens_q = metadata.cu_seqlens_k
-
+                metadata.max_seq_len_q = metadata.max_seq_len_k
         self.forward_metadata = metadata
 
     def forward_extend(
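
Note (illustration only): the strided-index step above converts a per-token slot table into a per-page table when page_size > 1, by taking every page_size-th slot and dividing by page_size. A small sketch with made-up slot values:

import torch

page_size = 4
# Hypothetical req_to_token row: token i of this request lives in KV slot req_to_token[0, i].
req_to_token = torch.tensor([[8, 9, 10, 11, 20, 21, 22, 23]])

strided_indices = torch.arange(0, req_to_token.shape[1], page_size)  # tensor([0, 4])
page_table = req_to_token[:, strided_indices] // page_size
print(page_table)  # tensor([[2, 5]]) -> this request's KV cache occupies pages 2 and 5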
@@ -105,23 +122,30 @@ class FlashAttentionBackend(AttentionBackend):
         forward_batch: ForwardBatch,
         save_kv_cache=True,
     ):
-        cache_loc = (
-            forward_batch.out_cache_loc
-            if not layer.is_cross_attention
-            else forward_batch.encoder_out_cache_loc
-        )
 
         if k is not None:
             assert v is not None
             if save_kv_cache:
-                forward_batch.token_to_kv_pool.set_kv_buffer(
-                    layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
                 )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # # Use Flash Attention for prefill
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -130,26 +154,72 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
-        key_cache, value_cache = kv_cache[0], kv_cache[1]
-        o = flash_attn_with_kvcache(
-            q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=metadata.max_seq_len_q,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        page_table = metadata.page_table
+
+        # # Use Flash Attention for prefill
+        if not self.use_mla:
+            # Do multi-head attention
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+            o = flash_attn_with_kvcache(
+                q=q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=metadata.max_seq_len_q,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def forward_decode(
         self,
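
Note (illustration only): in the MLA branch above, one latent-KV cache row holds both the compressed KV vector and the rotary key part, recovered by slicing the last dimension (the first v_head_dim columns are the latent part, the remainder the rope part). A shape-only sketch with hypothetical DeepSeek-style dimensions:

import torch

head_dim, v_head_dim = 576, 512          # hypothetical: 512 latent dims + 64 rope dims
kv_cache = torch.randn(8, 1, head_dim)   # (num_slots, tp_k_head_num, head_dim)

c_kv = kv_cache[:, :, :v_head_dim]       # compressed latent KV (acts as the value cache)
k_rope = kv_cache[:, :, v_head_dim:]     # rotary part of the key

print(c_kv.shape, k_rope.shape)  # torch.Size([8, 1, 512]) torch.Size([8, 1, 64])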
@@ -162,26 +232,29 @@ class FlashAttentionBackend(AttentionBackend):
     ) -> torch.Tensor:
         """Forward pass with FlashAttention using precomputed metadata."""
         # Save KV cache if needed
-        if k is not None:
-
-
-
-
-
-
-
-
-
-
-
-
+        if k is not None:
+            assert v is not None
+            if save_kv_cache:
+                cache_loc = (
+                    forward_batch.out_cache_loc
+                    if not layer.is_cross_attention
+                    else forward_batch.encoder_out_cache_loc
+                )
+                if not self.use_mla:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer, cache_loc, k, v, layer.k_scale, layer.v_scale
+                    )
+                else:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(
+                        layer,
+                        cache_loc,
+                        k,
+                        v,
+                    )
 
         # Use precomputed metadata
         metadata = self.forward_metadata
 
-        # Pre-reshape query tensor
-        q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
-
         # Calculate window size (can be moved to metadata if layer properties don't change)
         # we don't do layer.sliding_window_size - 1 since in model.get_attention_sliding_window_size() we already - 1
         # here is two side inclusive
@@ -190,25 +263,79 @@ class FlashAttentionBackend(AttentionBackend):
             if layer.sliding_window_size is not None
             else (-1, -1)
         )
-        # Run attention with precomputed values
-        o = flash_attn_with_kvcache(
-            q=q_reshaped,
-            k_cache=key_cache.unsqueeze(1),
-            v_cache=value_cache.unsqueeze(1),
-            page_table=metadata.page_table,
-            cache_seqlens=metadata.cache_seqlens_int32,
-            cu_seqlens_q=metadata.cu_seqlens_q,
-            cu_seqlens_k_new=metadata.cu_seqlens_k,
-            max_seqlen_q=1,
-            softmax_scale=layer.scaling,
-            causal=True,
-            window_size=window_size,
-            softcap=layer.logit_cap,
-            k_descale=layer.k_scale,
-            v_descale=layer.v_scale,
-        )
 
-        return o.view(-1, layer.tp_q_head_num * layer.head_dim)
+        page_table = metadata.page_table
+
+        if not self.use_mla:
+            # Do multi-head attention
+
+            # Get KV cache
+            kv_cache = forward_batch.token_to_kv_pool.get_kv_buffer(layer.layer_id)
+            key_cache, value_cache = kv_cache[0], kv_cache[1]
+            key_cache = key_cache.view(
+                -1, self.page_size, layer.tp_k_head_num, layer.head_dim
+            )
+            value_cache = value_cache.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.head_dim
+            )
+
+            # Pre-reshape query tensor
+            q_reshaped = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+
+            # Run attention with precomputed values
+            o = flash_attn_with_kvcache(
+                q=q_reshaped,
+                k_cache=key_cache,
+                v_cache=value_cache,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                window_size=window_size,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+        else:
+            # Do absorbed multi-latent attention
+            kv_cache = forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id)
+            k_rope = kv_cache[:, :, layer.v_head_dim :]
+            c_kv = kv_cache[:, :, : layer.v_head_dim]
+            k_rope_cache = k_rope.view(
+                -1,
+                self.page_size,
+                layer.tp_k_head_num,
+                layer.head_dim - layer.v_head_dim,
+            )
+            c_kv_cache = c_kv.view(
+                -1, self.page_size, layer.tp_v_head_num, layer.v_head_dim
+            )
+
+            q_all = q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim)
+            q_nope = q_all[:, :, : layer.v_head_dim]
+            q_rope = q_all[:, :, layer.v_head_dim :]
+
+            o = flash_attn_with_kvcache(
+                q=q_rope,
+                k_cache=k_rope_cache,
+                v_cache=c_kv_cache,
+                qv=q_nope,
+                page_table=page_table,
+                cache_seqlens=metadata.cache_seqlens_int32,
+                cu_seqlens_q=metadata.cu_seqlens_q,
+                cu_seqlens_k_new=metadata.cu_seqlens_k,
+                max_seqlen_q=1,
+                softmax_scale=layer.scaling,
+                causal=True,
+                softcap=layer.logit_cap,
+                k_descale=layer.k_scale,
+                v_descale=layer.v_scale,
+            )
+
+        return o.view(-1, layer.tp_q_head_num * layer.v_head_dim)
 
     def init_cuda_graph_state(self, max_bs: int):
         """Initialize CUDA graph state for the attention backend.
@@ -223,7 +350,13 @@ class FlashAttentionBackend(AttentionBackend):
         self.decode_cuda_graph_metadata = {
             # Page table for token mapping (batch_size, max_context_len)
             "page_table": torch.zeros(
-                max_bs, self.max_context_len, dtype=torch.int32, device=self.device
+                max_bs,
+                (self.max_context_len + self.page_size - 1) // self.page_size,
+                dtype=torch.int32,
+                device=self.device,
+            ),
+            "strided_indices": torch.arange(
+                0, self.max_context_len, self.page_size, device=self.device
             ),
         }
 
@@ -274,21 +407,27 @@ class FlashAttentionBackend(AttentionBackend):
         seq_lens_cpu: Optional[torch.Tensor],
     ):
         # """Initialize forward metadata for replaying CUDA graph."""
-        seqlens_in_batch = seq_lens[:bs]
         metadata = self.decode_cuda_graph_metadata[bs]
-
+
+        # For CPU operations
+        max_len = seq_lens_cpu[:bs].max().item()
+        metadata.max_seq_len_k = max_len
+
+        # For GPU operations
+        seq_lens_in_batch = seq_lens[:bs]
+        metadata.cache_seqlens_int32 = seq_lens_in_batch.to(torch.int32)
         metadata.cu_seqlens_k = torch.nn.functional.pad(
-            torch.cumsum(
+            torch.cumsum(seq_lens_in_batch, dim=0, dtype=torch.int32), (1, 0)
         )
-
-        metadata.max_seq_len_k
-
-
-
-
-
-        )
-        self.
+
+        max_seq_pages = (metadata.max_seq_len_k + self.page_size - 1) // self.page_size
+        page_indices = self.req_to_token[
+            :, self.decode_cuda_graph_metadata["strided_indices"][:max_seq_pages]
+        ]
+        page_indices = page_indices[req_pool_indices[:bs]] // self.page_size
+        metadata.page_table[:, :max_seq_pages].copy_(page_indices)
+        metadata.page_table[:, max_seq_pages:].fill_(0)
+        self.forward_metadata = metadata
 
     def get_cuda_graph_seq_len_fill_value(self):
         """Get the fill value for sequence length in CUDA graph."""