sglang 0.3.3.post1__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents exactly as they appear in the public registry.
- sglang/bench_latency.py +28 -10
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +4 -0
- sglang/srt/managers/schedule_batch.py +120 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +202 -140
- sglang/srt/managers/tokenizer_manager.py +5 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +60 -1
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +92 -49
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +92 -58
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +116 -17
- sglang/srt/server_args.py +121 -45
- sglang/srt/utils.py +11 -3
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/METADATA +72 -29
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/RECORD +73 -70
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.post1.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -17,6 +17,11 @@ limitations under the License.
 
 import json
 import logging
+import threading
+import time
+from queue import Queue
+
+import torch
 
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
@@ -75,6 +80,7 @@ class TpModelWorker:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
+        self.device = self.model_runner.device
 
         # Profile number of tokens
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
@@ -100,6 +106,9 @@ class TpModelWorker:
         )[0]
         set_random_seed(self.random_seed)
 
+        if server_args.enable_overlap_schedule:
+            self.init_overlap_status()
+
     def get_token_and_memory_info(self):
         return (
             self.max_total_num_tokens,
@@ -109,6 +118,81 @@ class TpModelWorker:
             self.random_seed,
         )
 
+    def init_overlap_status(self):
+        self.future_logits_output_dict = dict()
+        self.future_logits_output_ct = 0
+        self.future_token_ids_ct = 0
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_output = dict()
+
+        self.future_event_map = dict()
+        self.forward_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            tic1 = time.time()
+            model_worker_batch, future_logits_output, future_next_token_ids = (
+                self.forward_queue.get()
+            )
+
+            # Resolve future tokens in the input
+            tic2 = time.time()
+            resolved_input_ids = model_worker_batch.input_ids
+            future_mask = resolved_input_ids < 0
+            resolved_input_ids[future_mask] = self.future_token_ids_map[
+                -resolved_input_ids[future_mask]
+            ]
+
+            # Run forward
+            logits_output, next_token_ids = self.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Set future values
+            if model_worker_batch.return_logprob:
+                self.future_logits_output_dict[future_logits_output] = logits_output
+
+            # logger.info(f"set output {future_next_token_ids=}, {next_token_ids=}")
+            self.future_token_ids_map[-future_next_token_ids] = next_token_ids.to(
+                torch.int32
+            )
+            # logger.info("Set event")
+            self.future_token_ids_output[model_worker_batch.bid] = (
+                next_token_ids.tolist()
+            )
+            self.future_event_map[model_worker_batch.bid].set()
+
+            if False:
+                tic3 = time.time()
+                self.acc_time_with_waiting += tic3 - tic1
+                self.acc_time_without_waiting += tic3 - tic2
+                if self.forward_queue.qsize() == 0:
+                    logger.info(
+                        f"{self.acc_time_with_waiting=:.3f}, {self.acc_time_without_waiting=:.3f}, {self.forward_queue.qsize()=}"
+                    )
+
+    def resolve_future_token_ids(self, bid: int):
+        self.future_event_map[bid].wait()
+        ret = self.future_token_ids_output[bid]
+        del self.future_event_map[bid]
+        return ret
+
+    def resolve_future_logits_output(self, future_obj):
+        return self.future_logits_output_dict.pop(future_obj)
+
     def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
@@ -118,9 +202,35 @@ class TpModelWorker:
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
         logits_output = self.model_runner.forward(forward_batch)
-        embeddings = logits_output.embeddings
+        embeddings = logits_output.embeddings
         return embeddings
 
+    def forward_batch_generation_non_blocking(
+        self, model_worker_batch: ModelWorkerBatch
+    ):
+        # Allocate output future objects
+        future_logits_output = self.future_logits_output_ct
+        self.future_logits_output_ct += 1
+
+        bs = len(model_worker_batch.seq_lens)
+        with torch.cuda.stream(self.forward_stream):
+            future_next_token_ids = -torch.arange(
+                self.future_token_ids_ct + 1,
+                self.future_token_ids_ct + 1 + bs,
+                dtype=torch.int32,
+                device=self.device,
+            )
+            self.future_token_ids_ct = (
+                self.future_token_ids_ct + bs
+            ) % self.future_token_ids_limit
+        ret = future_logits_output, future_next_token_ids
+
+        self.future_event_map[model_worker_batch.bid] = threading.Event()
+        self.forward_queue.put(
+            (model_worker_batch.copy(), future_logits_output, future_next_token_ids)
+        )
+        return ret
+
     def update_weights(self, recv_req: UpdateWeightReqInput):
         success, message = self.model_runner.update_weights(
             recv_req.model_path, recv_req.load_format
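The overlap-schedule additions above work by handing out negative placeholder token ids before the forward pass has finished: `forward_batch_generation_non_blocking` reserves slots, the background `forward_thread_func_` later writes the sampled ids into `future_token_ids_map` at the negated slots, and a later batch whose inputs contain negative ids resolves them with one indexed gather. A minimal CPU-only sketch of that handshake (illustrative only, not the sglang API; int64 is used so it runs anywhere):

```python
import torch

# Map from placeholder slot -> real token id (slot 0 unused). sglang keeps this
# on the GPU as int32; plain int64 is used here so the sketch runs on CPU.
future_token_ids_map = torch.zeros(16, dtype=torch.int64)

# Scheduler side: hand out negative placeholders -1, -2, -3 for a batch of 3
# requests before any token has actually been sampled.
ct, bs = 0, 3
future_next_token_ids = -torch.arange(ct + 1, ct + 1 + bs)
ct += bs

# Worker side (the background forward thread): publish the sampled ids at the
# negated placeholder slots once the forward pass finishes.
sampled = torch.tensor([101, 7, 42])
future_token_ids_map[-future_next_token_ids] = sampled

# A later batch arrives with placeholders mixed into its input ids; resolve
# them with a single indexed gather, as forward_thread_func_ does above.
input_ids = torch.tensor([5, -1, 9, -3])
mask = input_ids < 0
input_ids[mask] = future_token_ids_map[-input_ids[mask]]
print(input_ids.tolist())  # [5, 101, 9, 42]
```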
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -40,10 +40,12 @@ class ChunkCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-
+            token_id_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :token_id_len
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool.free(kv_indices)
@@ -53,10 +55,12 @@ class ChunkCache(BasePrefixCache):
 
     def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         if token_ids is None:
-
+            token_id_len = len(req.fill_ids)
+        else:
+            token_id_len = len(token_ids)
 
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx, :
+            req.req_pool_idx, :token_id_len
         ]
 
         if req.rid not in self.entries:
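The chunk-cache change above frees only the `req_to_token` slots that were actually written for the request (`token_id_len`) instead of slicing the whole padded row. A toy illustration of the bounded slice (shapes and values are made up):

```python
import torch

# Hypothetical request-to-token map: one padded row of KV slot indices per request.
max_context_len = 8
req_to_token = torch.arange(2 * max_context_len).reshape(2, max_context_len)

req_pool_idx = 1
origin_input_ids = [11, 12, 13]
output_ids = [14, 15]

# Free only the slots actually written for this request: prompt tokens plus
# generated tokens, minus the last one that never received a KV entry.
token_id_len = len(origin_input_ids) + len(output_ids) - 1
kv_indices = req_to_token[req_pool_idx, :token_id_len]
print(kv_indices.tolist())  # 4 slot indices, not the full padded row of 8
```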
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -18,7 +18,6 @@ limitations under the License.
 import logging
 from typing import List, Tuple, Union
 
-import numpy as np
 import torch
 
 logger = logging.getLogger(__name__)
@@ -77,6 +76,8 @@ class BaseTokenToKVPool:
         self.store_dtype = dtype
 
         self.free_slots = None
+        self.is_not_in_free_group = True
+        self.free_group = []
         self.clear()
 
     def available_size(self):
@@ -89,14 +90,28 @@ class BaseTokenToKVPool:
         select_index = self.free_slots[:need_size]
         self.free_slots = self.free_slots[need_size:]
 
-        return
+        return select_index.to(self.device, non_blocking=True)
 
     def free(self, free_index: torch.Tensor):
-
+        if self.is_not_in_free_group:
+            self.free_slots = torch.concat((self.free_slots, free_index.cpu()))
+        else:
+            self.free_group.append(free_index)
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.concat(self.free_group))
 
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
-        self.free_slots =
+        self.free_slots = torch.arange(1, self.size + 1, dtype=torch.int32)
+        self.is_in_free_group = False
+        self.free_group = []
 
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
         raise NotImplementedError()
@@ -231,3 +246,61 @@ class MLATokenToKVPool(BaseTokenToKVPool):
             self.kv_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
         else:
             self.kv_buffer[layer_id][loc] = cache_k
+
+
+class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
+
+    def __init__(
+        self,
+        size: int,
+        dtype: torch.dtype,
+        head_num: int,
+        head_dim: int,
+        layer_num: int,
+        device: str,
+        heavy_channel_num: int,
+    ):
+        super().__init__(size, dtype, device)
+
+        # [size, head_num, head_dim] for each layer
+        self.k_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+        self.v_buffer = [
+            torch.empty((size + 1, head_num, head_dim), dtype=dtype, device=device)
+            for _ in range(layer_num)
+        ]
+
+        # [size, head_num, heavy_channel_num] for each layer
+        self.label_buffer = [
+            torch.empty(
+                (size + 1, head_num, heavy_channel_num), dtype=dtype, device=device
+            )
+            for _ in range(layer_num)
+        ]
+
+    def get_key_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id]
+
+    def get_value_buffer(self, layer_id: int):
+        return self.v_buffer[layer_id]
+
+    def get_label_buffer(self, layer_id: int):
+        return self.label_buffer[layer_id]
+
+    def get_kv_buffer(self, layer_id: int):
+        return self.k_buffer[layer_id], self.v_buffer[layer_id]
+
+    def set_kv_buffer(
+        self,
+        layer_id: int,
+        loc: torch.Tensor,
+        cache_k: torch.Tensor,
+        cache_v: torch.Tensor,
+        cache_label: torch.Tensor,
+    ):
+        # NOTE(Andy): ignore the dtype check
+        self.k_buffer[layer_id][loc] = cache_k
+        self.v_buffer[layer_id][loc] = cache_v
+        self.label_buffer[layer_id][loc] = cache_label
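The `free_group_begin`/`free_group_end` pair added to `BaseTokenToKVPool` lets callers defer many small `free()` calls and release them as a single concatenated tensor, so the free-slot list is rebuilt once rather than once per request. A standalone sketch of the same pattern (a simplified stand-in class, not the sglang pool):

```python
import torch

class SlotPool:
    """Illustrative free-list with optional grouped frees (not the sglang class)."""

    def __init__(self, size: int):
        # Slot 0 is reserved, matching the padded-slot convention above.
        self.free_slots = torch.arange(1, size + 1, dtype=torch.int32)
        self.is_not_in_free_group = True
        self.free_group = []

    def alloc(self, need_size: int) -> torch.Tensor:
        out, self.free_slots = self.free_slots[:need_size], self.free_slots[need_size:]
        return out

    def free(self, free_index: torch.Tensor):
        if self.is_not_in_free_group:
            self.free_slots = torch.concat((self.free_slots, free_index))
        else:
            self.free_group.append(free_index)  # defer until free_group_end()

    def free_group_begin(self):
        self.is_not_in_free_group = False
        self.free_group = []

    def free_group_end(self):
        self.is_not_in_free_group = True
        if self.free_group:
            self.free(torch.concat(self.free_group))

pool = SlotPool(8)
a, b = pool.alloc(3), pool.alloc(2)
pool.free_group_begin()
pool.free(a)
pool.free(b)          # both frees are merged into one concat here
pool.free_group_end()
print(pool.free_slots.tolist())
```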
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -99,17 +99,25 @@ class RadixCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
         """Cache request when it finishes."""
+        if self.disable:
+            if token_ids is None:
+                token_ids_len = len(req.origin_input_ids) + len(req.output_ids) - 1
+            else:
+                token_ids_len = len(token_ids)
+
+            kv_indices = self.req_to_token_pool.req_to_token[
+                req.req_pool_idx, :token_ids_len
+            ]
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return
+
         if token_ids is None:
             token_ids = (req.origin_input_ids + req.output_ids)[:-1]
         kv_indices = self.req_to_token_pool.req_to_token[
             req.req_pool_idx, : len(token_ids)
         ]
 
-        if self.disable:
-            self.token_to_kv_pool.free(kv_indices)
-            self.req_to_token_pool.free(req.req_pool_idx)
-            return
-
         # Radix Cache takes one ref in memory pool
         new_prefix_len = self.insert(token_ids, kv_indices.clone())
         self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
@@ -229,7 +237,7 @@ class RadixCache(BasePrefixCache):
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
-        new_node.children = {key[split_len
+        new_node.children = {key[split_len]: child}
         new_node.parent = child.parent
         new_node.lock_ref = child.lock_ref
         new_node.key = child.key[:split_len]
@@ -237,7 +245,7 @@ class RadixCache(BasePrefixCache):
         child.parent = new_node
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
-        new_node.parent.children[key[
+        new_node.parent.children[key[0]] = new_node
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
sglang/srt/model_executor/cuda_graph_runner.py
CHANGED
@@ -245,10 +245,10 @@ class CudaGraphRunner:
             self.out_cache_loc.zero_()
 
         # Common inputs
-        self.input_ids[:raw_bs]
-        self.req_pool_indices[:raw_bs]
-        self.seq_lens[:raw_bs]
-        self.out_cache_loc[:raw_bs]
+        self.input_ids[:raw_bs].copy_(forward_batch.input_ids)
+        self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
+        self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
+        self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
 
         # Attention backend
         self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
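The CUDA-graph change above writes the incoming batch tensors into the preallocated static buffers with in-place `copy_` rather than rebinding the attributes, because a captured graph replays against the original storage. A device-agnostic sketch of that in-place fill (buffer sizes are illustrative):

```python
import torch

# Static buffers sized for the largest batch the graph was captured with.
max_bs = 8
input_ids = torch.zeros(max_bs, dtype=torch.int64)
seq_lens = torch.zeros(max_bs, dtype=torch.int32)

# A smaller incoming batch must be written into the same storage in place;
# rebinding the names would leave the captured graph reading stale memory.
raw_bs = 3
new_input_ids = torch.tensor([5, 6, 7])
new_seq_lens = torch.tensor([10, 11, 12], dtype=torch.int32)

input_ids[:raw_bs].copy_(new_input_ids)
seq_lens[:raw_bs].copy_(new_seq_lens)
print(input_ids.tolist(), seq_lens.tolist())
```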
sglang/srt/model_executor/forward_batch_info.py
CHANGED
@@ -118,7 +118,7 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-        device =
+        device = model_runner.device
 
         ret = cls(
             forward_mode=batch.forward_mode,
@@ -134,27 +134,23 @@ class ForwardBatch:
         )
 
         # Init position information
-        if ret.forward_mode.is_decode():
-            ret.positions =
-
-
-
-
-
-
-
-
-
-                ],
-                axis=0,
-                ),
-                device=device,
-            ).to(torch.int64)
-
+        if not ret.forward_mode.is_decode():
+            ret.positions = torch.concat(
+                [
+                    torch.arange(prefix_len, prefix_len + extend_len, device=device)
+                    for prefix_len, extend_len in zip(
+                        batch.extend_prefix_lens, batch.extend_seq_lens
+                    )
+                ],
+                axis=0,
+            )
             ret.image_inputs = batch.image_inputs
-            ret.extend_seq_lens = torch.tensor(
+            ret.extend_seq_lens = torch.tensor(
+                batch.extend_seq_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_prefix_lens = torch.tensor(
-                batch.extend_prefix_lens,
-            )
+                batch.extend_prefix_lens, dtype=torch.int32
+            ).to(device, non_blocking=True)
             ret.extend_start_loc = torch.zeros_like(ret.extend_seq_lens)
             ret.extend_start_loc[1:] = torch.cumsum(ret.extend_seq_lens[:-1], dim=0)
             ret.extend_seq_lens_cpu = batch.extend_seq_lens
@@ -164,7 +160,6 @@ class ForwardBatch:
         ret.req_to_token_pool = model_runner.req_to_token_pool
         ret.token_to_kv_pool = model_runner.token_to_kv_pool
        ret.attn_backend = model_runner.attn_backend
-        model_runner.attn_backend.init_forward_metadata(ret)
 
         # Init lora information
         if model_runner.server_args.lora_paths is not None:
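The rewritten position logic above builds positions for extend (prefill) batches by concatenating one `arange` per request, offset by its cached prefix length; the decode case is now handled in model_runner.py, where `forward_batch.positions = (forward_batch.seq_lens - 1)` is set just before the forward call. A small sketch of the extend case:

```python
import torch

extend_prefix_lens = [4, 0]  # tokens already cached per request
extend_seq_lens = [3, 2]     # new tokens being prefilled per request

# Position ids for the flattened extend batch: each request continues
# counting from the end of its cached prefix.
positions = torch.concat(
    [
        torch.arange(prefix_len, prefix_len + extend_len)
        for prefix_len, extend_len in zip(extend_prefix_lens, extend_seq_lens)
    ],
    axis=0,
)
print(positions.tolist())  # [4, 5, 6, 0, 1]
```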
sglang/srt/model_executor/model_runner.py
CHANGED
@@ -18,6 +18,7 @@ limitations under the License.
 import gc
 import importlib
 import importlib.resources
+import json
 import logging
 import pkgutil
 from functools import lru_cache
@@ -39,6 +40,7 @@ from vllm.model_executor.models import ModelRegistry
 
 from sglang.srt.configs.model_config import AttentionArch, ModelConfig
 from sglang.srt.constrained import disable_cache
+from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
@@ -46,6 +48,7 @@ from sglang.srt.layers.sampler import Sampler
 from sglang.srt.lora.lora_manager import LoRAManager
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.memory_pool import (
+    DoubleSparseTokenToKVPool,
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
@@ -99,6 +102,20 @@ class ModelRunner:
             logger.info("MLA optimization is turned on. Use triton backend.")
             self.server_args.attention_backend = "triton"
 
+        if self.server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            self.server_args.attention_backend = "triton"
+            self.server_args.disable_cuda_graph = True
+            if self.server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(
+                self.server_args.ds_heavy_channel_type
+            )
+
         if self.is_multimodal_model:
             logger.info(
                 "Automatically turn off --chunked-prefill-size and adjust --mem-fraction-static for multimodal models."
@@ -119,6 +136,8 @@ class ModelRunner:
                 "triton_attention_reduce_in_fp32": server_args.triton_attention_reduce_in_fp32,
                 "disable_mla": server_args.disable_mla,
                 "torchao_config": server_args.torchao_config,
+                "disable_penalizer": server_args.disable_penalizer,
+                "disable_nan_detection": server_args.disable_nan_detection,
             }
         )
 
@@ -138,6 +157,7 @@ class ModelRunner:
             self.init_attention_backend()
             self.init_cuda_graphs()
         else:
+            self.cuda_graph_runner = None
             self.init_attention_backend()
 
     def init_torch_distributed(self):
@@ -146,6 +166,11 @@ class ModelRunner:
         if self.device == "cuda":
             torch.cuda.set_device(self.gpu_id)
             backend = "nccl"
+        # ToDO(liangan1):Just use gloo to bypass the initilization fail
+        # Need to use xccl for xpu backend in the future
+        elif self.device == "xpu":
+            torch.xpu.set_device(self.gpu_id)
+            backend = "gloo"
 
         if not self.server_args.enable_p2p_check:
             monkey_patch_vllm_p2p_access_check(self.gpu_id)
@@ -432,6 +457,16 @@ class ModelRunner:
                 layer_num=self.model_config.num_hidden_layers,
                 device=self.device,
             )
+        elif self.server_args.enable_double_sparsity:
+            self.token_to_kv_pool = DoubleSparseTokenToKVPool(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                head_num=self.model_config.get_num_kv_heads(self.tp_size),
+                head_dim=self.model_config.head_dim,
+                layer_num=self.model_config.num_hidden_layers,
+                device=self.device,
+                heavy_channel_num=self.server_args.ds_heavy_channel_num,
+            )
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
@@ -468,12 +503,33 @@ class ModelRunner:
                     "Cross attention is not supported in the triton attention backend. "
                    "Please use `--attention-backend flashinfer`."
                 )
-            self.
+            if self.server_args.enable_double_sparsity:
+                self.attn_backend = DoubleSparseAttnBackend(self)
+            else:
+                self.attn_backend = TritonAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
             )
 
+    def init_double_sparsity_channel_config(self, selected_channel):
+
+        selected_channel = "." + selected_channel + "_proj"
+        self.sorted_channels = []
+        # load channel config
+        with open(self.server_args.ds_channel_config_path, "r") as f:
+            channel_config = json.load(f)
+
+        for i in range(self.model_config.num_hidden_layers):
+            key = "model.layers." + str(i) + ".self_attn" + selected_channel
+            self.sorted_channels.append(
+                torch.tensor(channel_config[key])[
+                    :, : self.server_args.ds_heavy_channel_num
+                ]
+                .contiguous()
+                .cuda()
+            )
+
     def init_cuda_graphs(self):
         """Capture cuda graphs."""
         from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
@@ -496,11 +552,14 @@ class ModelRunner:
         ):
             return self.cuda_graph_runner.replay(forward_batch)
 
+        forward_batch.positions = (forward_batch.seq_lens - 1).to(torch.int64)
+        self.attn_backend.init_forward_metadata(forward_batch)
         return self.model.forward(
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
     def forward_extend(self, forward_batch: ForwardBatch):
+        self.attn_backend.init_forward_metadata(forward_batch)
         if self.is_generation:
             return self.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
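The new `init_double_sparsity_channel_config` loads a per-layer channel ranking from JSON and keeps only the first `ds_heavy_channel_num` entries per head for the selected projection. A hedged sketch of the expected file shape and slicing (the file path, layer count, and head/channel sizes below are invented for illustration):

```python
import json
import torch

# Invented channel-config file: one entry per layer/projection, each a
# [head_num, head_dim] ranking of channel indices (most important first).
channel_config = {
    "model.layers.0.self_attn.q_proj": [[3, 1, 0, 2], [2, 0, 3, 1]],
    "model.layers.1.self_attn.q_proj": [[0, 2, 1, 3], [1, 3, 2, 0]],
}
config_path = "/tmp/ds_channel_config.json"  # stand-in for server_args.ds_channel_config_path
with open(config_path, "w") as f:
    json.dump(channel_config, f)

num_hidden_layers = 2
ds_heavy_channel_num = 2
sorted_channels = []
with open(config_path) as f:
    loaded = json.load(f)
for i in range(num_hidden_layers):
    key = f"model.layers.{i}.self_attn.q_proj"
    # Keep only the heaviest channels per head; the real code also moves these to GPU.
    sorted_channels.append(
        torch.tensor(loaded[key])[:, :ds_heavy_channel_num].contiguous()
    )
print(sorted_channels[0].tolist())  # [[3, 1], [2, 0]]
```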
sglang/srt/models/baichuan.py
CHANGED
@@ -24,7 +24,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -330,7 +329,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -404,7 +403,7 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
sglang/srt/models/chatglm.py
CHANGED
@@ -22,7 +22,6 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.config import CacheConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -52,7 +51,7 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -188,7 +187,7 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -260,7 +259,7 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -308,7 +307,7 @@ class ChatGLMModel(nn.Module):
     def __init__(
         self,
         config,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -359,7 +358,7 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config
+        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         lora_config: Optional[LoraConfig] = None,
     ):
sglang/srt/models/commandr.py
CHANGED
@@ -45,7 +45,6 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -320,7 +319,7 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config
+        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/dbrx.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.config import CacheConfig
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -368,7 +367,7 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config
+        cache_config=None,
     ):
         super().__init__()
         self.config = config
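Each model file in this release drops its `from vllm.config import CacheConfig` import and takes `cache_config` as a plain keyword defaulting to `None`, removing the dependency on vllm's config type from the constructor signature. A minimal sketch of the pattern (the class below is a placeholder, not an sglang model):

```python
from typing import Optional

class ToyForCausalLM:
    """Placeholder class showing the new constructor shape; not an sglang model."""

    def __init__(
        self,
        config,
        cache_config=None,  # the removed CacheConfig import suggests this was previously typed against it
        quant_config: Optional[object] = None,
    ) -> None:
        self.config = config
        self.quant_config = quant_config
        # cache_config is accepted for call-site compatibility but not used here.

model = ToyForCausalLM(config={"hidden_size": 4096})
```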