sglang 0.3.6.post2__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff shows the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
- sglang/bench_offline_throughput.py +55 -2
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +4 -3
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/launch_server.py +3 -2
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +6 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/flashinfer_backend.py +13 -15
- sglang/srt/layers/attention/torch_native_backend.py +285 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +34 -0
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -11
- sglang/srt/managers/detokenizer_manager.py +7 -4
- sglang/srt/managers/image_processor.py +1 -1
- sglang/srt/managers/io_struct.py +48 -12
- sglang/srt/managers/schedule_batch.py +42 -36
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +111 -46
- sglang/srt/managers/session_controller.py +0 -3
- sglang/srt/managers/tokenizer_manager.py +169 -100
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +32 -5
- sglang/srt/model_executor/cuda_graph_runner.py +16 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +136 -150
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +2 -3
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +3 -11
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +14 -51
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +97 -27
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +10 -12
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +12 -5
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +391 -0
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -8
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -11
- sglang/srt/models/qwen2_vl.py +12 -9
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -12
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +10 -6
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/server.py +303 -204
- sglang/srt/server_args.py +65 -31
- sglang/srt/utils.py +253 -48
- sglang/test/test_utils.py +27 -7
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/METADATA +2 -1
- sglang-0.4.0.dist-info/RECORD +184 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +0 -1
- sglang/srt/layers/fused_moe_grok/fused_moe.py +0 -692
- sglang/srt/layers/fused_moe_grok/layer.py +0 -630
- sglang-0.3.6.post2.dist-info/RECORD +0 -164
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post2.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/utils.py (new file)

```diff
@@ -0,0 +1,223 @@
+# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/utils.py
+# Copyright 2023 The vLLM team.
+# Adapted from
+# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/tensor_parallel/utils.py
+# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
+import dataclasses
+import logging
+import os
+import pickle
+import time
+from collections import deque
+from typing import Any, Deque, Dict, Optional, Sequence, Tuple
+
+import torch
+from torch.distributed import TCPStore
+
+logger = logging.getLogger(__name__)
+
+
+def ensure_divisibility(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator."""
+    assert numerator % denominator == 0, "{} is not divisible by {}".format(
+        numerator, denominator
+    )
+
+
+def divide(numerator, denominator):
+    """Ensure that numerator is divisible by the denominator and return
+    the division value."""
+    ensure_divisibility(numerator, denominator)
+    return numerator // denominator
+
+
+def split_tensor_along_last_dim(
+    tensor: torch.Tensor,
+    num_partitions: int,
+    contiguous_split_chunks: bool = False,
+) -> Sequence[torch.Tensor]:
+    """Split a tensor along its last dimension.
+
+    Arguments:
+        tensor: input tensor.
+        num_partitions: number of partitions to split the tensor
+        contiguous_split_chunks: If True, make each chunk contiguous
+            in memory.
+
+    Returns:
+        A list of Tensors
+    """
+    # Get the size and dimension.
+    last_dim = tensor.dim() - 1
+    last_dim_size = divide(tensor.size()[last_dim], num_partitions)
+    # Split.
+    tensor_list = torch.split(tensor, last_dim_size, dim=last_dim)
+    # NOTE: torch.split does not create contiguous tensors by default.
+    if contiguous_split_chunks:
+        return tuple(chunk.contiguous() for chunk in tensor_list)
+
+    return tensor_list
+
+
+def get_pp_indices(
+    num_hidden_layers: int, pp_rank: int, pp_size: int
+) -> Tuple[int, int]:
+    """Try to evenly distribute layers across partitions.
+    If the number of layers is not divisible by the number of partitions,
+    the last partition will have the remaining layers.
+    """
+    # partition_list_str can be set to None in sglang
+    partition_list_str = os.getenv("SGLANG_PP_LAYER_PARTITION", None)
+    if partition_list_str is not None:
+        try:
+            partitions = [int(layer) for layer in partition_list_str.split(",")]
+        except ValueError as err:
+            raise ValueError(
+                "Invalid partition string: {}".format(partition_list_str)
+            ) from err
+        if len(partitions) != pp_size:
+            raise ValueError(f"{len(partitions)=} does not match {pp_size=}.")
+        if sum(partitions) != num_hidden_layers:
+            raise ValueError(f"{sum(partitions)=} does not match {num_hidden_layers=}.")
+        start_layer = sum(partitions[:pp_rank])
+        end_layer = start_layer + partitions[pp_rank]
+    else:
+        layers_per_partition = num_hidden_layers // pp_size
+        start_layer = pp_rank * layers_per_partition
+        end_layer = start_layer + layers_per_partition
+
+        if pp_rank == pp_size - 1:
+            end_layer = num_hidden_layers
+
+    return (start_layer, end_layer)
+
+
+@dataclasses.dataclass
+class StatelessProcessGroup:
+    """A dataclass to hold a metadata store, and the rank, world_size of the
+    group. Only use it to communicate metadata between processes.
+    For data-plane communication, create NCCL-related objects.
+    """
+
+    rank: int
+    world_size: int
+    store: torch._C._distributed_c10d.Store
+    data_expiration_seconds: int = 3600  # 1 hour
+
+    # dst rank -> counter
+    send_dst_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    # src rank -> counter
+    recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+    broadcast_send_counter: int = 0
+    broadcast_recv_src_counter: Dict[int, int] = dataclasses.field(default_factory=dict)
+
+    # A deque to store the data entries, with key and timestamp.
+    entries: Deque[Tuple[str, float]] = dataclasses.field(default_factory=deque)
+
+    def __post_init__(self):
+        assert self.rank < self.world_size
+        self.send_dst_counter = {i: 0 for i in range(self.world_size)}
+        self.recv_src_counter = {i: 0 for i in range(self.world_size)}
+        self.broadcast_recv_src_counter = {i: 0 for i in range(self.world_size)}
+
+    def send_obj(self, obj: Any, dst: int):
+        """Send an object to a destination rank."""
+        self.expire_data()
+        key = f"send_to/{dst}/{self.send_dst_counter[dst]}"
+        self.store.set(key, pickle.dumps(obj))
+        self.send_dst_counter[dst] += 1
+        self.entries.append((key, time.time()))
+
+    def expire_data(self):
+        """Expire data that is older than `data_expiration_seconds` seconds."""
+        while self.entries:
+            # check the oldest entry
+            key, timestamp = self.entries[0]
+            if time.time() - timestamp > self.data_expiration_seconds:
+                self.store.delete_key(key)
+                self.entries.popleft()
+            else:
+                break
+
+    def recv_obj(self, src: int) -> Any:
+        """Receive an object from a source rank."""
+        obj = pickle.loads(
+            self.store.get(f"send_to/{self.rank}/{self.recv_src_counter[src]}")
+        )
+        self.recv_src_counter[src] += 1
+        return obj
+
+    def broadcast_obj(self, obj: Optional[Any], src: int) -> Any:
+        """Broadcast an object from a source rank to all other ranks.
+        It does not clean up after all ranks have received the object.
+        Use it for limited times, e.g., for initialization.
+        """
+        if self.rank == src:
+            self.expire_data()
+            key = f"broadcast_from/{src}/" f"{self.broadcast_send_counter}"
+            self.store.set(key, pickle.dumps(obj))
+            self.broadcast_send_counter += 1
+            self.entries.append((key, time.time()))
+            return obj
+        else:
+            key = f"broadcast_from/{src}/" f"{self.broadcast_recv_src_counter[src]}"
+            recv_obj = pickle.loads(self.store.get(key))
+            self.broadcast_recv_src_counter[src] += 1
+            return recv_obj
+
+    def all_gather_obj(self, obj: Any) -> list[Any]:
+        """All gather an object from all ranks."""
+        gathered_objs = []
+        for i in range(self.world_size):
+            if i == self.rank:
+                gathered_objs.append(obj)
+                self.broadcast_obj(obj, src=self.rank)
+            else:
+                recv_obj = self.broadcast_obj(None, src=i)
+                gathered_objs.append(recv_obj)
+        return gathered_objs
+
+    def barrier(self):
+        """A barrier to synchronize all ranks."""
+        for i in range(self.world_size):
+            if i == self.rank:
+                self.broadcast_obj(None, src=self.rank)
+            else:
+                self.broadcast_obj(None, src=i)
+
+    @staticmethod
+    def create(
+        host: str,
+        port: int,
+        rank: int,
+        world_size: int,
+        data_expiration_seconds: int = 3600,
+    ) -> "StatelessProcessGroup":
+        """A replacement for `torch.distributed.init_process_group` that does not
+        pollute the global state.
+
+        If we have process A and process B called `torch.distributed.init_process_group`
+        to form a group, and then we want to form another group with process A, B, C,
+        D, it is not possible in PyTorch, because process A and process B have already
+        formed a group, and process C and process D cannot join that group. This
+        function is a workaround for this issue.
+
+        `torch.distributed.init_process_group` is a global call, while this function
+        is a stateless call. It will return a `StatelessProcessGroup` object that can be
+        used for exchanging metadata. With this function, process A and process B
+        can call `StatelessProcessGroup.create` to form a group, and then process A, B,
+        C, and D can call `StatelessProcessGroup.create` to form another group.
+        """  # noqa
+        store = TCPStore(
+            host_name=host,
+            port=port,
+            world_size=world_size,
+            is_master=(rank == 0),
+        )
+
+        return StatelessProcessGroup(
+            rank=rank,
+            world_size=world_size,
+            store=store,
+            data_expiration_seconds=data_expiration_seconds,
+        )
```
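The heart of this new module is `StatelessProcessGroup`, a control-plane helper that exchanges small pickled objects through a `TCPStore` without touching PyTorch's global process-group state. Below is a minimal usage sketch; the host, port, and two-process layout are illustrative choices, not part of the package.

```python
# Sketch: two local processes form a stateless group and all-gather their ranks.
import multiprocessing as mp

from sglang.srt.distributed.utils import StatelessProcessGroup


def worker(rank: int, world_size: int, port: int):
    group = StatelessProcessGroup.create(
        host="127.0.0.1", port=port, rank=rank, world_size=world_size
    )
    # Exchange small picklable metadata only; tensors should go over NCCL/gloo.
    gathered = group.all_gather_obj({"rank": rank})
    group.barrier()
    print(f"rank {rank} gathered: {gathered}")


if __name__ == "__main__":
    world_size, port = 2, 29999  # any free TCP port works
    procs = [
        mp.Process(target=worker, args=(rank, world_size, port))
        for rank in range(world_size)
    ]
    for p in procs:
        p.start()
    for p in procs:
        p.join()
```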
sglang/srt/hf_transformers_utils.py

```diff
@@ -16,6 +16,7 @@
 import contextlib
 import os
 import warnings
+from pathlib import Path
 from typing import Dict, Optional, Type, Union

 from huggingface_hub import snapshot_download
@@ -27,6 +28,7 @@ from transformers import (
     PreTrainedTokenizer,
     PreTrainedTokenizerFast,
 )
+from transformers.models.auto.modeling_auto import MODEL_FOR_CAUSAL_LM_MAPPING_NAMES

 try:
     from vllm.transformers_utils.configs import ChatGLMConfig, DbrxConfig
@@ -60,15 +62,31 @@ def get_config(
     trust_remote_code: bool,
     revision: Optional[str] = None,
     model_override_args: Optional[dict] = None,
+    **kwargs,
 ):
+    is_gguf = check_gguf_file(model)
+    if is_gguf:
+        kwargs["gguf_file"] = model
+        model = Path(model).parent
+
     config = AutoConfig.from_pretrained(
-        model, trust_remote_code=trust_remote_code, revision=revision
+        model, trust_remote_code=trust_remote_code, revision=revision, **kwargs
     )
     if config.model_type in _CONFIG_REGISTRY:
         config_class = _CONFIG_REGISTRY[config.model_type]
         config = config_class.from_pretrained(model, revision=revision)
+        # NOTE(HandH1998): Qwen2VL requires `_name_or_path` attribute in `config`.
+        setattr(config, "_name_or_path", model)
     if model_override_args:
         config.update(model_override_args)
+
+    # Special architecture mapping check for GGUF models
+    if is_gguf:
+        if config.model_type not in MODEL_FOR_CAUSAL_LM_MAPPING_NAMES:
+            raise RuntimeError(f"Can't get gguf config for {config.model_type}.")
+        model_type = MODEL_FOR_CAUSAL_LM_MAPPING_NAMES[config.model_type]
+        config.update({"architectures": [model_type]})
+
     return config


@@ -123,6 +141,11 @@ def get_tokenizer(
             raise ValueError("Cannot use the fast tokenizer in slow tokenizer mode.")
         kwargs["use_fast"] = False

+    is_gguf = check_gguf_file(tokenizer_name)
+    if is_gguf:
+        kwargs["gguf_file"] = tokenizer_name
+        tokenizer_name = Path(tokenizer_name).parent
+
     try:
         tokenizer = AutoTokenizer.from_pretrained(
             tokenizer_name,
@@ -195,3 +218,16 @@ def attach_additional_stop_token_ids(tokenizer):
         )
     else:
         tokenizer.additional_stop_token_ids = None
+
+
+def check_gguf_file(model: Union[str, os.PathLike]) -> bool:
+    """Check if the file is a GGUF model."""
+    model = Path(model)
+    if not model.is_file():
+        return False
+    elif model.suffix == ".gguf":
+        return True
+
+    with open(model, "rb") as f:
+        header = f.read(4)
+    return header == b"GGUF"
```
sglang/srt/layers/attention/flashinfer_backend.py

```diff
@@ -18,7 +18,11 @@ import triton.language as tl
 from sglang.global_config import global_config
 from sglang.srt.layers.attention import AttentionBackend
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    get_bool_env_var,
+    is_flashinfer_available,
+    should_use_tensor_core,
+)

 if TYPE_CHECKING:
     from sglang.srt.layers.radix_attention import RadixAttention
@@ -31,7 +35,6 @@ if is_flashinfer_available():
         BatchPrefillWithRaggedKVCacheWrapper,
     )
     from flashinfer.cascade import merge_state
-    from flashinfer.decode import _grouped_size_compiled_for_decode_kernels


 class WrapperDispatch(Enum):
@@ -45,19 +48,14 @@ class FlashInferAttnBackend(AttentionBackend):
     def __init__(self, model_runner: ModelRunner):
         super().__init__()

-
-
-
-
-
-
-
-
-            model_runner.model_config.get_num_kv_heads(model_runner.tp_size),
-        ):
-            self.decode_use_tensor_cores = True
-        else:
-            self.decode_use_tensor_cores = False
+        self.decode_use_tensor_cores = should_use_tensor_core(
+            kv_cache_dtype=model_runner.kv_cache_dtype,
+            num_attention_heads=model_runner.model_config.num_attention_heads
+            // model_runner.tp_size,
+            num_kv_heads=model_runner.model_config.get_num_kv_heads(
+                model_runner.tp_size
+            ),
+        )

         self.max_context_len = model_runner.model_config.context_len

```
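The tensor-core decision is now delegated to `should_use_tensor_core` in `sglang.srt.utils`, removing the backend's dependency on flashinfer's `_grouped_size_compiled_for_decode_kernels`. The helper's body is not part of this diff; the sketch below only illustrates the kind of heuristic such a function might apply, and its name suffix, environment-variable handling, and thresholds are assumptions rather than the package's actual logic.

```python
# Hypothetical sketch only: the real should_use_tensor_core lives in sglang.srt.utils
# and is not shown in this diff.
import os
from typing import Union

import torch


def should_use_tensor_core_sketch(
    kv_cache_dtype: Union[str, torch.dtype],
    num_attention_heads: int,
    num_kv_heads: int,
) -> bool:
    # An explicit environment override wins (assumed variable name).
    env = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
    if env is not None:
        return env.lower() in ("1", "true")
    # FP8 KV caches are typically only served by tensor-core decode kernels.
    if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
        return True
    # For grouped-query attention, a large query-head group per KV head tends to
    # favor the tensor-core path (threshold is an assumption).
    gqa_group_size = num_attention_heads // num_kv_heads
    return gqa_group_size >= 4
```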
sglang/srt/layers/attention/torch_native_backend.py (new file)

```diff
@@ -0,0 +1,285 @@
+from __future__ import annotations
+
+from typing import TYPE_CHECKING, Optional
+
+import torch
+from torch.nn.functional import scaled_dot_product_attention
+
+from sglang.srt.layers.attention import AttentionBackend
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+
+if TYPE_CHECKING:
+    from sglang.srt.layers.radix_attention import RadixAttention
+    from sglang.srt.model_executor.model_runner import ModelRunner
+
+
+class TorchNativeAttnBackend(AttentionBackend):
+    def __init__(self, model_runner: ModelRunner):
+        super().__init__()
+        self.forward_metadata = None
+        self.device = model_runner.device
+
+    def init_forward_metadata(self, forward_batch: ForwardBatch):
+        """Init the metadata for a forward pass."""
+        pass
+
+    def init_cuda_graph_state(self, max_bs: int):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_capture_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def init_forward_metadata_replay_cuda_graph(
+        self,
+        bs: int,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        seq_lens_sum: int,
+        encoder_lens: Optional[torch.Tensor] = None,
+    ):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def get_cuda_graph_seq_len_fill_value(self):
+        # TODO: Support CUDA graph
+        raise ValueError(
+            "Torch native attention does not support CUDA graph for now. Please --disable-cuda-graph"
+        )
+
+    def _run_sdpa_forward_extend(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        extend_prefix_lens: torch.Tensor,
+        extend_seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the extend forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            extend_prefix_lens: [num_seqs]
+            extend_seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        assert seq_lens.shape[0] == extend_prefix_lens.shape[0]
+        assert seq_lens.shape[0] == extend_seq_lens.shape[0]
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            extend_seq_len_q = extend_seq_lens[seq_idx]
+            prefill_seq_len_q = extend_prefix_lens[seq_idx]
+
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + extend_seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+            per_req_query_redudant = torch.empty(
+                (per_req_query.shape[0], seq_len_kv, per_req_query.shape[2]),
+                dtype=per_req_query.dtype,
+                device=per_req_query.device,
+            )
+
+            per_req_query_redudant[:, prefill_seq_len_q:, :] = per_req_query
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out_redudant = (
+                scaled_dot_product_attention(
+                    per_req_query_redudant.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out_redudant[prefill_seq_len_q:, :, :]
+            start_q, start_kv = end_q, end_kv
+        return output
+
+    def _run_sdpa_forward_decode(
+        self,
+        query: torch.Tensor,
+        output: torch.Tensor,
+        k_cache: torch.Tensor,
+        v_cache: torch.Tensor,
+        req_to_token: torch.Tensor,
+        req_pool_indices: torch.Tensor,
+        seq_lens: torch.Tensor,
+        scaling=None,
+        enable_gqa=False,
+        causal=False,
+    ):
+        """Run the decode forward by using torch native sdpa op.
+
+        Args:
+            query: [num_tokens, num_heads, head_size]
+            output: [num_tokens, num_heads, head_size]
+            k_cache: [max_total_num_tokens, num_heads, head_size]
+            v_cache: [max_total_num_tokens, num_heads, head_size]
+            req_to_token: [max_num_reqs, max_context_len]
+            req_pool_indices: [num_seqs]
+            seq_lens: [num_seqs]
+            scaling: float or None
+            enable_gqa: bool
+            causal: bool
+
+        Returns:
+            output: [num_tokens, num_heads, head_size]
+        """
+
+        # [num_tokens, num_heads, head_size] -> [num_heads, num_tokens, head_size]
+        query = query.movedim(0, query.dim() - 2)
+
+        start_q, start_kv = 0, 0
+        for seq_idx in range(seq_lens.shape[0]):
+            # TODO: this loop process a sequence per iter, this is inefficient.
+            # Need optimize the performance later.
+
+            seq_len_q = 1
+            seq_len_kv = seq_lens[seq_idx]
+            end_q = start_q + seq_len_q
+            end_kv = start_kv + seq_len_kv
+
+            per_req_query = query[:, start_q:end_q, :]
+
+            # get key and value from cache. per_req_tokens contains the kv cache
+            # index for each token in the sequence.
+            req_pool_idx = req_pool_indices[seq_idx]
+            per_req_tokens = req_to_token[req_pool_idx, :seq_len_kv]
+            per_req_key = k_cache[per_req_tokens].movedim(0, query.dim() - 2)
+            per_req_value = v_cache[per_req_tokens].movedim(0, query.dim() - 2)
+
+            per_req_out = (
+                scaled_dot_product_attention(
+                    per_req_query.unsqueeze(0),
+                    per_req_key.unsqueeze(0),
+                    per_req_value.unsqueeze(0),
+                    enable_gqa=enable_gqa,
+                    scale=scaling,
+                    is_causal=causal,
+                )
+                .squeeze(0)
+                .movedim(query.dim() - 2, 0)
+            )
+            output[start_q:end_q, :, :] = per_req_out
+            start_q, start_kv = end_q, end_kv
+
+        return output
+
+    def forward_extend(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, forward_batch.out_cache_loc, k, v
+        )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_extend(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            forward_batch.extend_prefix_lens,
+            forward_batch.extend_seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=not layer.is_cross_attention,
+        )
+        return o
+
+    def forward_decode(
+        self, q, k, v, layer: RadixAttention, forward_batch: ForwardBatch
+    ):
+        # During torch.compile, there is a bug in rotary_emb that causes the
+        # output value to have a 3D tensor shape. This reshapes the output correctly.
+        q = q.reshape(-1, layer.tp_q_head_num * layer.qk_head_dim)
+
+        if layer.qk_head_dim != layer.v_head_dim:
+            o = q.new_empty((q.shape[0], layer.tp_q_head_num * layer.v_head_dim))
+        else:
+            o = torch.empty_like(q)
+
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            layer, forward_batch.out_cache_loc, k, v
+        )
+
+        use_gqa = layer.tp_q_head_num != layer.tp_k_head_num
+
+        q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
+        o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)
+
+        self._run_sdpa_forward_decode(
+            q_,
+            o_,
+            forward_batch.token_to_kv_pool.get_key_buffer(layer.layer_id),
+            forward_batch.token_to_kv_pool.get_value_buffer(layer.layer_id),
+            forward_batch.req_to_token_pool.req_to_token,
+            forward_batch.req_pool_indices,
+            forward_batch.seq_lens,
+            scaling=layer.scaling,
+            enable_gqa=use_gqa,
+            causal=False,
+        )
+
+        return o
```
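Both SDPA paths in the new backend follow the same pattern: look up each request's KV-cache slots through `req_to_token`, gather the keys and values, and run `torch.nn.functional.scaled_dot_product_attention` one request at a time. Below is a self-contained toy version of the decode path; the shapes and values are made up for illustration and do not come from the package.

```python
# Toy decode-path sketch: gather each request's K/V rows from a flat cache via its
# token indices, then call the torch SDPA op on the single new query token.
import torch
from torch.nn.functional import scaled_dot_product_attention

num_heads, head_size = 4, 8
max_total_num_tokens = 64

k_cache = torch.randn(max_total_num_tokens, num_heads, head_size)
v_cache = torch.randn(max_total_num_tokens, num_heads, head_size)

# Two decoding requests with 5 and 3 cached tokens respectively.
seq_lens = torch.tensor([5, 3])
req_to_token = torch.zeros(2, 16, dtype=torch.long)
req_to_token[0, :5] = torch.arange(0, 5)
req_to_token[1, :3] = torch.arange(10, 13)

query = torch.randn(2, num_heads, head_size)  # one new token per request
output = torch.empty_like(query)

q = query.movedim(0, 1)  # [num_heads, num_tokens, head_size]
for seq_idx in range(seq_lens.shape[0]):
    per_req_tokens = req_to_token[seq_idx, : seq_lens[seq_idx]]
    per_req_key = k_cache[per_req_tokens].movedim(0, 1)    # [num_heads, seq_len, head_size]
    per_req_value = v_cache[per_req_tokens].movedim(0, 1)
    per_req_query = q[:, seq_idx : seq_idx + 1, :]
    out = scaled_dot_product_attention(
        per_req_query.unsqueeze(0),
        per_req_key.unsqueeze(0),
        per_req_value.unsqueeze(0),
    ).squeeze(0)
    output[seq_idx] = out.movedim(1, 0).squeeze(0)

print(output.shape)  # torch.Size([2, 4, 8])
```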
sglang/srt/layers/fused_moe_patch.py

```diff
@@ -105,20 +105,29 @@ def fused_moe_forward_native(
     num_expert_group: Optional[int] = None,
     custom_routing_function: Optional[Callable] = None,
 ) -> torch.Tensor:
-
-
-
-
-
-
-
-
-
-
+
+    if use_grouped_topk:
+        assert num_expert_group is not None and topk_group is not None
+        topk_weights, topk_ids = grouped_topk(
+            x,
+            router_logits,
+            top_k,
+            renormalize,
+            num_expert_group,
+            topk_group,
+        )
+    elif custom_routing_function is None:
+        topk_weights, topk_ids = fused_topk_native(x, router_logits, top_k, renormalize)
+    else:
+        topk_weights, topk_ids = custom_routing_function(
+            x, router_logits, top_k, renormalize
+        )
+
     w13_weights = layer.w13_weight[topk_ids]
     w1_weights, w3_weights = torch.chunk(w13_weights, 2, dim=2)
     w2_weights = layer.w2_weight[topk_ids]
-    x1 =
+    x1 = torch.einsum("ti,taoi -> tao", x, w1_weights)
+    x1 = F.silu(x1)
     x3 = torch.einsum("ti, taoi -> tao", x, w3_weights)
     expert_outs = torch.einsum("tao, taio -> tai", (x1 * x3), w2_weights)
     return torch.einsum("tai,ta -> ti", expert_outs, topk_weights.to(expert_outs.dtype))
```