sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/utils.py (new file)
@@ -0,0 +1,223 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+# Copyright 2023 The vLLM team.
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import dataclasses
+import logging
+import os
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+
+import torch
+from torch.distributed import TCPStore
+
+logger = logging.getLogger(__name__)
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator
+    )
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+                                 in memory.
+
+    Returns:
+        A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+def get_pp_indices(
+    num_hidden_layers: int, pp_rank: int, pp_size: int
+) -> Tuple[int, int]:
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    # partition_list_str can be set to None in sglang
+    partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
+    if partition_list_str is not None:
+        try:
+            partitions = [int(layer) for layer in partition_list_str.split(",")]
+        except ValueError as err:
+            raise ValueError(
+                "Invalid partition string: {}".format(partition_list_str)
+            ) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+    else:
+        layers_per_partition = num_hidden_layers // pp_size
+        start_layer = pp_rank * layers_per_partition
+        end_layer = start_layer + layers_per_partition
+
+        if pp_rank == pp_size - 1:
+            end_layer = num_hidden_layers
+
+    return (start_layer, end_layer)
+
+
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank, world_size of the
+    group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
+
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
+
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
+
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
+
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
+
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
+        )
+        self.recv_src_counter[src] += 1
+        return obj
+
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it for limited times, e.g., for initialization.
+        """
+        if self.rank == src:
+            self.expire_data()
+            key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
+            self.store.set(key, pickle.dumps(obj))
+            self.broadcast_send_counter += 1
+            self.entries.append((key, time.time()))
+            return obj
+        else:
+            key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
+            recv_obj = pickle.loads(self.store.get(key))
+            self.broadcast_recv_src_counter[src] += 1
+            return recv_obj
+
+    def all_gather_obj(self, obj: Any) -> list[Any]:
+        """All gather an object from all ranks."""
+        gathered_objs = []
+        for i in range(self.world_size):
+            if i == self.rank:
+                gathered_objs.append(obj)
+                self.broadcast_obj(obj, src=self.rank)
+            else:
+                recv_obj = self.broadcast_obj(None, src=i)
+                gathered_objs.append(recv_obj)
+        return gathered_objs
+
+    def barrier(self):
+        """A barrier to synchronize all ranks."""
+        for i in range(self.world_size):
+            if i == self.rank:
+                self.broadcast_obj(None, src=self.rank)
+            else:
+                self.broadcast_obj(None, src=i)
+
+    @staticmethod
+    def create(
+        host: str,
+        port: int,
+        rank: int,
+        world_size: int,
+        data_expiration_seconds: int = 3600,
+    ) -> "StatelessProcessGroup":
+        """A replacement for `torch.distributed.init_process_group` that does not
+        pollute the global state.
+
+        If we have process A and process B called `torch.distributed.init_process_group`
+        to form a group, and then we want to form another group with process A, B, C,
+        D, it is not possible in PyTorch, because process A and process B have already
+        formed a group, and process C and process D cannot join that group. This
+        function is a workaround for this issue.
+
+        `torch.distributed.init_process_group` is a global call, while this function
+        is a stateless call. It will return a `StatelessProcessGroup` object that can be
+        used for exchanging metadata. With this function, process A and process B
+        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+        C, and D can call `StatelessProcessGroup.create` to form another group.
+        """  # noqa
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
+
+        return StatelessProcessGroup(
+            rank=rank,
+            world_size=world_size,
+            store=store,
+            data_expiration_seconds=data_expiration_seconds,
+        )
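Note: `StatelessProcessGroup` is a control-plane helper. It exchanges pickled metadata through a `TCPStore` without touching the global `torch.distributed` state, so the same processes can later form other, overlapping groups. A minimal driving sketch (the spawn harness, host, and port are illustrative, not part of the package):

# Illustrative harness for the StatelessProcessGroup added above; the
# host/port values and the multiprocessing setup are assumptions.
import multiprocessing as mp

from sglang.srt.distributed.utils import StatelessProcessGroup


def worker(rank: int, world_size: int):
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=29555, rank=rank, world_size=world_size
    )
    # Metadata-only collectives, all backed by the TCPStore.
    gathered = group.all_gather_obj({"rank": rank})
    group.barrier()
    if rank == 0:
        print(gathered)  # [{'rank': 0}, {'rank': 1}]


if __name__ == "__main__":
    procs = [mp.Process(target=worker, args=(r, 2)) for r in range(2)]
    for p in procs:
        p.start()
    for p in procs:
        p.join()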
sglang/srt/hf_transformers_utils.py
@@ -16,6 +16,7 @@
 import contextlib
 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Type, Union
 
 from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES
 
 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,31 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
+    **kwargs,
 ):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
+        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+        setattr(config, "_name_or_path", model)
     if model_override_args:
         config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
+
     return config
 
 
@@ -123,6 +141,11 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False
 
+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -195,3 +218,16 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
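Note: `check_gguf_file` treats a path as GGUF if it has the `.gguf` suffix or starts with the 4-byte magic `b"GGUF"`; `get_config` and `get_tokenizer` then hand the file to transformers via the `gguf_file` kwarg while rerouting the model path to the parent directory. A hypothetical call (the checkpoint path is made up):

from sglang.srt.hf_transformers_utils import get_config, get_tokenizer

# Hypothetical local path to a single-file GGUF checkpoint.
gguf_path = "/models/llama-3-8b/model-q4_k_m.gguf"

# Internally: kwargs["gguf_file"] = gguf_path, model = Path(gguf_path).parent,
# and the architecture is remapped via MODEL_FOR_CAUSAL_LM_MAPPING_NAMES.
config = get_config(gguf_path, trust_remote_code=False)
tokenizer = get_tokenizer(gguf_path)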
sglang/srt/layers/attention/__init__.py
@@ -52,12 +52,13 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run forward on an attention layer."""
         if forward_batch.forward_mode.is_decode():
-            return self.forward_decode(q, k, v, layer, forward_batch)
+            return self.forward_decode(q, k, v, layer, forward_batch, save_kv_cache)
         else:
-            return self.forward_extend(q, k, v, layer, forward_batch)
+            return self.forward_extend(q, k, v, layer, forward_batch, save_kv_cache)
 
     def forward_decode(
         self,
@@ -66,6 +67,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for decode."""
         raise NotImplementedError()
@@ -77,6 +79,7 @@ class AttentionBackend(ABC):
         v: torch.Tensor,
         layer: RadixAttention,
         forward_batch: ForwardBatch,
+        save_kv_cache: bool = True,
     ):
         """Run a forward for extend."""
         raise NotImplementedError()
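Note: every backend `forward_*` now accepts `save_kv_cache` (default `True`, preserving the old behavior); passing `False` computes attention against whatever the KV pool already holds without writing `k`/`v` back. A toy subclass, not sglang code, sketching the contract a backend honors:

import torch

from sglang.srt.layers.attention import AttentionBackend


class ToyBackend(AttentionBackend):
    """Toy example (not part of sglang) showing the save_kv_cache contract."""

    def forward_extend(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache and k is not None:
            # Same pattern as the real backends: write the pool only when asked.
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return torch.zeros_like(q)  # a real backend computes attention here

    def forward_decode(self, q, k, v, layer, forward_batch, save_kv_cache=True):
        if save_kv_cache and k is not None:
            forward_batch.token_to_kv_pool.set_kv_buffer(
                layer, forward_batch.out_cache_loc, k, v
            )
        return torch.zeros_like(q)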
sglang/srt/layers/attention/double_sparsity_backend.py
@@ -165,7 +165,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return 1
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # TODO: reuse the buffer across layers
         if layer.qk_head_dim != layer.v_head_dim:
@@ -181,9 +187,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         (
             start_loc,
@@ -212,7 +219,13 @@ class DoubleSparseAttnBackend(AttentionBackend):
         return o
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         # During torch.compile, there is a bug in rotary_emb that causes the
         # output value to have a 3D tensor shape. This reshapes the output correctly.
@@ -242,9 +255,10 @@ class DoubleSparseAttnBackend(AttentionBackend):
             .expand(k.shape[0], -1, -1),
         )
 
-        forward_batch.token_to_kv_pool.set_kv_buffer(
-            layer, forward_batch.out_cache_loc, k, v, k_label
-        )
+        if save_kv_cache:
+            forward_batch.token_to_kv_pool.set_kv_buffer(
+                layer, forward_batch.out_cache_loc, k, v, k_label
+            )
 
         # NOTE(Andy) shouldn't be used when max_len_in_batch < heavy_token_num
         # and set a minimum value for sparse_decode
sglang/srt/layers/attention/flashinfer_backend.py
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import is_flashinfer_available
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)
 
 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
 
 
 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()
 
-        if not _grouped_size_compiled_for_decode_kernels(
-            model_runner.model_config.num_attention_heads // model_runner.tp_size,
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-        ):
-            self.decode_use_tensor_cores = True
-        else:
-            self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )
 
         self.max_context_len = model_runner.model_config.context_len
 
@@ -223,7 +221,13 @@ class FlashInferAttnBackend(AttentionBackend):
         return 0
 
     def forward_extend(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         prefill_wrapper_paged = self.prefill_wrappers_paged[
             self._get_wrapper_idx(layer)
@@ -239,7 +243,8 @@ class FlashInferAttnBackend(AttentionBackend):
         if not use_ragged:
             if k is not None:
                 assert v is not None
-                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+                if save_kv_cache:
+                    forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
             o = prefill_wrapper_paged.forward(
                 q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
@@ -272,12 +277,19 @@ class FlashInferAttnBackend(AttentionBackend):
 
             o, _ = merge_state(o1, s1, o2, s2)
 
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         return o.view(-1, layer.tp_q_head_num * layer.head_dim)
 
     def forward_decode(
-        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+        self,
+        q,
+        k,
+        v,
+        layer: RadixAttention,
+        forward_batch: ForwardBatch,
+        save_kv_cache=True,
     ):
         decode_wrapper = self.forward_metadata[0][self._get_wrapper_idx(layer)]
         cache_loc = (
@@ -288,7 +300,8 @@ class FlashInferAttnBackend(AttentionBackend):
 
         if k is not None:
             assert v is not None
-            forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
+            if save_kv_cache:
+                forward_batch.token_to_kv_pool.set_kv_buffer(layer, cache_loc, k, v)
 
         o = decode_wrapper.forward(
             q.contiguous().view(-1, layer.tp_q_head_num, layer.head_dim),
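Note: the inline heuristic built on FlashInfer's private `_grouped_size_compiled_for_decode_kernels` is replaced by a single `should_use_tensor_core` helper in `sglang.srt.utils`, which also considers the KV-cache dtype. A rough sketch of that kind of decision (an approximation only; the helper's actual rules and thresholds may differ, and the env-var override is an assumption inferred from the `get_bool_env_var` import):

# Approximation of the decision should_use_tensor_core centralizes; the real
# helper in sglang.srt.utils may differ. SGLANG_FLASHINFER_USE_TENSOR_CORE
# as an override is an assumption.
import os

import torch


def should_use_tensor_core_sketch(
    kv_cache_dtype: torch.dtype,
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    env = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
    if env is not None:
        return env.lower() == "true"
    # FP8 KV caches require FlashInfer's tensor-core decode path.
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    # Large GQA group sizes amortize the tensor-core overhead.
    return num_attention_heads // num_kv_heads >= 4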