sglang 0.3.4__py3-none-any.whl → 0.3.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
- sglang/bench_latency.py +2 -1
- sglang/lang/chat_template.py +17 -0
- sglang/launch_server_llavavid.py +1 -1
- sglang/srt/configs/__init__.py +3 -0
- sglang/srt/configs/model_config.py +27 -2
- sglang/srt/configs/qwen2vl.py +133 -0
- sglang/srt/constrained/fsm_cache.py +10 -3
- sglang/srt/conversation.py +27 -0
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/__init__.py +16 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -6
- sglang/srt/layers/attention/flashinfer_backend.py +174 -54
- sglang/srt/layers/attention/triton_backend.py +22 -6
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +26 -4
- sglang/srt/layers/linear.py +89 -63
- sglang/srt/layers/logits_processor.py +5 -5
- sglang/srt/layers/rotary_embedding.py +112 -0
- sglang/srt/layers/sampler.py +51 -39
- sglang/srt/lora/lora.py +3 -1
- sglang/srt/managers/data_parallel_controller.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +4 -0
- sglang/srt/managers/image_processor.py +186 -13
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/schedule_batch.py +238 -68
- sglang/srt/managers/scheduler.py +69 -50
- sglang/srt/managers/tokenizer_manager.py +24 -4
- sglang/srt/managers/tp_worker.py +26 -111
- sglang/srt/managers/tp_worker_overlap_thread.py +209 -0
- sglang/srt/mem_cache/memory_pool.py +56 -10
- sglang/srt/mem_cache/radix_cache.py +4 -3
- sglang/srt/model_executor/cuda_graph_runner.py +87 -28
- sglang/srt/model_executor/forward_batch_info.py +83 -3
- sglang/srt/model_executor/model_runner.py +32 -11
- sglang/srt/models/chatglm.py +3 -3
- sglang/srt/models/deepseek_v2.py +2 -2
- sglang/srt/models/mllama.py +1004 -0
- sglang/srt/models/qwen2_vl.py +724 -0
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +6 -3
- sglang/srt/sampling/sampling_batch_info.py +13 -3
- sglang/srt/sampling/sampling_params.py +5 -7
- sglang/srt/server.py +12 -0
- sglang/srt/server_args.py +10 -0
- sglang/srt/utils.py +22 -0
- sglang/test/run_eval.py +2 -0
- sglang/test/runners.py +20 -1
- sglang/test/srt/sampling/penaltylib/utils.py +1 -0
- sglang/test/test_utils.py +100 -3
- sglang/version.py +1 -1
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/METADATA +17 -18
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/RECORD +53 -48
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/LICENSE +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.3.4.dist-info → sglang-0.3.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker_overlap_thread.py (new file):

```diff
@@ -0,0 +1,209 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""A tensor parallel worker."""
+
+import logging
+import threading
+import time
+from queue import Queue
+from typing import Optional
+
+import torch
+
+from sglang.srt.managers.io_struct import UpdateWeightReqInput
+from sglang.srt.managers.schedule_batch import ModelWorkerBatch
+from sglang.srt.managers.tp_worker import TpModelWorker
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.server_args import ServerArgs
+
+logger = logging.getLogger(__name__)
+
+
+@torch.compile(dynamic=True)
+def resolve_future_token_ids(input_ids, future_token_ids_map):
+    input_ids[:] = torch.where(
+        input_ids < 0,
+        future_token_ids_map[torch.clamp(-input_ids, min=0)],
+        input_ids,
+    )
+
+
+class TpModelWorkerClient:
+    """A tensor parallel model worker."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        gpu_id: int,
+        tp_rank: int,
+        dp_rank: Optional[int],
+        nccl_port: int,
+    ):
+        # Load the model
+        self.worker = TpModelWorker(server_args, gpu_id, tp_rank, dp_rank, nccl_port)
+        self.max_running_requests = self.worker.max_running_requests
+        self.device = self.worker.device
+
+        # Init future mappings
+        self.future_token_ids_ct = 0
+        self.future_token_ids_limit = self.max_running_requests * 3
+        self.future_token_ids_map = torch.empty(
+            (self.max_running_requests * 5,), dtype=torch.int32, device=self.device
+        )
+
+        # Launch threads
+        self.input_queue = Queue()
+        self.output_queue = Queue()
+        self.forward_stream = torch.cuda.Stream()
+        self.forward_thread = threading.Thread(
+            target=self.forward_thread_func,
+        )
+        self.forward_thread.start()
+
+        self.copy_queue = Queue()
+        self.copy_thread = threading.Thread(
+            target=self.copy_thread_func,
+        )
+        self.copy_thread.start()
+
+    def get_worker_info(self):
+        return self.worker.get_worker_info()
+
+    def get_pad_input_ids_func(self):
+        return self.worker.get_pad_input_ids_func()
+
+    def get_tp_cpu_group(self):
+        return self.worker.get_tp_cpu_group()
+
+    def get_memory_pool(self):
+        return (
+            self.worker.model_runner.req_to_token_pool,
+            self.worker.model_runner.token_to_kv_pool,
+        )
+
+    def forward_thread_func(self):
+        with torch.cuda.stream(self.forward_stream):
+            self.forward_thread_func_()
+
+    @torch.inference_mode()
+    def forward_thread_func_(self):
+        while True:
+            self.has_inflight_batch = False
+            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            if not model_worker_batch:
+                break
+            self.has_inflight_batch = True
+            self.launch_event = threading.Event()
+
+            # Resolve future tokens in the input
+            input_ids = model_worker_batch.input_ids
+            resolve_future_token_ids(input_ids, self.future_token_ids_map)
+
+            # Run forward
+            logits_output, next_token_ids = self.worker.forward_batch_generation(
+                model_worker_batch
+            )
+
+            # Update the future token ids map
+            bs = len(model_worker_batch.seq_lens)
+            self.future_token_ids_map[
+                future_token_ids_ct + 1 : future_token_ids_ct + bs + 1
+            ] = next_token_ids
+
+            # Copy results to the CPU
+            if model_worker_batch.return_logprob:
+                logits_output.next_token_logprobs = logits_output.next_token_logprobs[
+                    torch.arange(len(next_token_ids), device=self.device),
+                    next_token_ids,
+                ].to("cpu", non_blocking=True)
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.to("cpu", non_blocking=True)
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.to(
+                            "cpu", non_blocking=True
+                        )
+                    )
+            next_token_ids = next_token_ids.to("cpu", non_blocking=True)
+            copy_event = torch.cuda.Event(blocking=True)
+            copy_event.record()
+
+            self.launch_event.set()
+            self.copy_queue.put((copy_event, logits_output, next_token_ids))
+
+    def copy_thread_func(self):
+        while True:
+            copy_event, logits_output, next_token_ids = self.copy_queue.get()
+            if not copy_event:
+                break
+            while not copy_event.query():
+                time.sleep(1e-5)
+
+            if logits_output.next_token_logprobs is not None:
+                logits_output.next_token_logprobs = (
+                    logits_output.next_token_logprobs.tolist()
+                )
+                if logits_output.input_token_logprobs is not None:
+                    logits_output.input_token_logprobs = (
+                        logits_output.input_token_logprobs.tolist()
+                    )
+                    logits_output.normalized_prompt_logprobs = (
+                        logits_output.normalized_prompt_logprobs.tolist()
+                    )
+
+            self.output_queue.put((logits_output, next_token_ids.tolist()))
+
+    def resulve_batch_result(self, bid: int):
+        logits_output, next_token_ids = self.output_queue.get()
+        if self.has_inflight_batch:
+            # Wait until the batch is launched
+            self.launch_event.wait()
+        return logits_output, next_token_ids
+
+    def forward_batch_generation(self, model_worker_batch: ModelWorkerBatch):
+        # Push a new batch to the queue
+        self.input_queue.put((model_worker_batch.copy(), self.future_token_ids_ct))
+
+        # Allocate output future objects
+        bs = len(model_worker_batch.seq_lens)
+        future_next_token_ids = torch.arange(
+            -(self.future_token_ids_ct + 1),
+            -(self.future_token_ids_ct + 1 + bs),
+            -1,
+            dtype=torch.int32,
+            device=self.device,
+        )
+        self.future_token_ids_ct = (
+            self.future_token_ids_ct + bs
+        ) % self.future_token_ids_limit
+        return None, future_next_token_ids
+
+    def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
+        forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
+        logits_output = self.model_runner.forward(forward_batch)
+        embeddings = logits_output.embeddings
+        return embeddings
+
+    def update_weights(self, recv_req: UpdateWeightReqInput):
+        success, message = self.model_runner.update_weights(
+            recv_req.model_path, recv_req.load_format
+        )
+        return success, message
+
+    def __delete__(self):
+        self.input_queue.put((None, None))
+        self.copy_queue.put((None, None, None))
```
sglang/srt/mem_cache/memory_pool.py:

```diff
@@ -13,27 +13,46 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""
+"""
+Memory pool.
+
+SGLang has two levels of memory pool.
+ReqToTokenPool maps a a request to its token locations.
+BaseTokenToKVPool maps a token location to its KV cache data.
+"""
 
 import logging
 from typing import List, Tuple, Union
 
 import torch
 
+from sglang.srt.layers.radix_attention import RadixAttention
+
 logger = logging.getLogger(__name__)
 
 
 class ReqToTokenPool:
     """A memory pool that maps a request to its token locations."""
 
-    def __init__(self, size: int, max_context_len: int, device: str):
+    def __init__(self, size: int, max_context_len: int, device: str, use_records: bool):
         self.size = size
         self.max_context_len = max_context_len
         self.device = device
-        self.req_to_token = torch.
+        self.req_to_token = torch.zeros(
             (size, max_context_len), dtype=torch.int32, device=device
         )
         self.free_slots = list(range(size))
+        self.write_records = []
+        self.use_records = use_records
+
+        if self.use_records:
+            self.write = self.write_with_records
+        else:
+            self.write = self.write_without_records
+
+    def write(self, indices, values):
+        # Keep the signature for type checking. It will be assigned during runtime.
+        raise NotImplementedError()
 
     def available_size(self):
         return len(self.free_slots)
@@ -55,10 +74,27 @@ class ReqToTokenPool:
 
     def clear(self):
         self.free_slots = list(range(self.size))
+        self.write_records = []
+
+    def write_without_records(self, indices, values):
+        self.req_to_token[indices] = values
+
+    def write_with_records(self, indices, values):
+        self.req_to_token[indices] = values
+        self.write_records.append((indices, values))
+
+    def get_write_records(self):
+        ret = self.write_records
+        self.write_records = []
+        return ret
+
+    def apply_write_records(self, write_records: List[Tuple]):
+        for indices, values in write_records:
+            self.req_to_token[indices] = values
 
 
 class BaseTokenToKVPool:
-    """A memory pool that maps a token to its kv cache
+    """A memory pool that maps a token location to its kv cache data."""
 
     def __init__(
         self,
@@ -68,12 +104,12 @@ class BaseTokenToKVPool:
     ):
         self.size = size
         self.dtype = dtype
-        self.device = device
         if dtype == torch.float8_e5m2:
             # NOTE: Store as torch.uint8 because Tensor index_put is not implemented for torch.float8_e5m2
             self.store_dtype = torch.uint8
         else:
             self.store_dtype = dtype
+        self.device = device
 
         self.free_slots = None
         self.is_not_in_free_group = True
@@ -124,7 +160,7 @@ class BaseTokenToKVPool:
 
     def set_kv_buffer(
         self,
-
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
@@ -179,14 +215,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
-        if cache_v.dtype != self.dtype:
             cache_v = cache_v.to(self.dtype)
         if self.store_dtype != self.dtype:
             self.k_buffer[layer_id][loc] = cache_k.view(self.store_dtype)
@@ -196,6 +232,14 @@ class MHATokenToKVPool(BaseTokenToKVPool):
             self.v_buffer[layer_id][loc] = cache_v
 
 
+# This compiled version is slower in the unit test
+# python3 -m unittest test_bench_serving.TestBenchServing.test_offline_throughput_non_stream_small_batch_size
+@torch.compile(dynamic=True)
+def copy_two_array(loc, dst_1, src_1, dst_2, src_2, dtype, store_dtype):
+    dst_1[loc] = src_1.to(dtype).view(store_dtype)
+    dst_2[loc] = src_2.to(dtype).view(store_dtype)
+
+
 class MLATokenToKVPool(BaseTokenToKVPool):
 
     def __init__(
@@ -235,11 +279,12 @@ class MLATokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
         self,
-
+        layer: RadixAttention,
         loc: torch.Tensor,
         cache_k: torch.Tensor,
         cache_v: torch.Tensor,
     ):
+        layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             cache_k = cache_k.to(self.dtype)
         if self.store_dtype != self.dtype:
@@ -294,13 +339,14 @@ class DoubleSparseTokenToKVPool(BaseTokenToKVPool):
 
     def set_kv_buffer(
        self,
-
+        layer: RadixAttention,
        loc: torch.Tensor,
        cache_k: torch.Tensor,
        cache_v: torch.Tensor,
        cache_label: torch.Tensor,
     ):
        # NOTE(Andy): ignore the dtype check
+        layer_id = layer.layer_id
        self.k_buffer[layer_id][loc] = cache_k
        self.v_buffer[layer_id][loc] = cache_v
        self.label_buffer[layer_id][loc] = cache_label
```
sglang/srt/mem_cache/radix_cache.py:

```diff
@@ -145,9 +145,10 @@ class RadixCache(BasePrefixCache):
             # The prefix indices could be updated, reuse it
             new_indices, new_last_node = self.match_prefix(token_ids)
             assert len(new_indices) == len(token_ids)
-            self.req_to_token_pool.
-                req.req_pool_idx, len(req.prefix_indices)
-
+            self.req_to_token_pool.write(
+                (req.req_pool_idx, slice(len(req.prefix_indices), len(new_indices))),
+                new_indices[len(req.prefix_indices) :],
+            )
 
             self.dec_lock_ref(req.last_node)
             self.inc_lock_ref(new_last_node)
```
|
@@ -92,6 +92,11 @@ def set_torch_compile_config():
|
|
92
92
|
torch._dynamo.config.accumulated_cache_size_limit = 1024
|
93
93
|
|
94
94
|
|
95
|
+
@torch.compile(dynamic=True)
|
96
|
+
def clamp_position(seq_lens):
|
97
|
+
return torch.clamp((seq_lens - 1), min=0).to(torch.int64)
|
98
|
+
|
99
|
+
|
95
100
|
class CudaGraphRunner:
|
96
101
|
"""A CudaGraphRunner runs the forward pass of a model with cuda graph and torch.compile."""
|
97
102
|
|
@@ -105,13 +110,13 @@ class CudaGraphRunner:
|
|
105
110
|
self.graph_memory_pool = None
|
106
111
|
self.use_torch_compile = model_runner.server_args.enable_torch_compile
|
107
112
|
self.disable_padding = model_runner.server_args.disable_cuda_graph_padding
|
113
|
+
self.is_encoder_decoder = self.model_runner.model_config.is_encoder_decoder
|
108
114
|
|
109
115
|
# Batch sizes to capture
|
110
116
|
if self.model_runner.server_args.disable_cuda_graph_padding:
|
111
117
|
self.capture_bs = list(range(1, 32)) + [64, 128]
|
112
118
|
else:
|
113
|
-
self.capture_bs = [1, 2, 4] + [i * 8 for i in range(1, 21)]
|
114
|
-
|
119
|
+
self.capture_bs = [1, 2, 3, 4] + [i * 8 for i in range(1, 21)]
|
115
120
|
self.capture_bs = [
|
116
121
|
bs for bs in self.capture_bs if bs <= model_runner.req_to_token_pool.size
|
117
122
|
]
|
@@ -128,10 +133,14 @@ class CudaGraphRunner:
|
|
128
133
|
# Attention backend
|
129
134
|
self.max_bs = max(self.capture_bs)
|
130
135
|
self.model_runner.attn_backend.init_cuda_graph_state(self.max_bs)
|
136
|
+
|
131
137
|
self.seq_len_fill_value = (
|
132
138
|
self.model_runner.attn_backend.get_cuda_graph_seq_len_fill_value()
|
133
139
|
)
|
134
140
|
|
141
|
+
# FIXME(lsyin): leave it here for now, I don't know whether it is necessary
|
142
|
+
self.encoder_len_fill_value = 0
|
143
|
+
|
135
144
|
if self.use_torch_compile:
|
136
145
|
set_torch_compile_config()
|
137
146
|
|
@@ -143,10 +152,20 @@ class CudaGraphRunner:
|
|
143
152
|
(self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
|
144
153
|
)
|
145
154
|
self.out_cache_loc = torch.zeros((self.max_bs,), dtype=torch.int32)
|
155
|
+
self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int32)
|
156
|
+
|
157
|
+
if self.is_encoder_decoder:
|
158
|
+
# NOTE: encoder_lens can influence the full_text_row_masked_out_mask tensor when doing mixed batch
|
159
|
+
self.encoder_lens = torch.full(
|
160
|
+
(self.max_bs,), self.encoder_len_fill_value, dtype=torch.int32
|
161
|
+
)
|
162
|
+
else:
|
163
|
+
self.encoder_lens = None
|
146
164
|
|
147
165
|
# Capture
|
148
166
|
try:
|
149
|
-
self.
|
167
|
+
with self.model_capture_mode():
|
168
|
+
self.capture()
|
150
169
|
except RuntimeError as e:
|
151
170
|
raise Exception(
|
152
171
|
f"Capture cuda graph failed: {e}\n"
|
@@ -157,11 +176,32 @@ class CudaGraphRunner:
|
|
157
176
|
"Open an issue on GitHub https://github.com/sgl-project/sglang/issues/new/choose \n"
|
158
177
|
)
|
159
178
|
|
160
|
-
|
161
|
-
|
162
|
-
|
163
|
-
|
164
|
-
|
179
|
+
@contextmanager
|
180
|
+
def model_capture_mode(self):
|
181
|
+
if hasattr(self.model_runner.model, "capture_mode"):
|
182
|
+
self.model_runner.model.capture_mode = True
|
183
|
+
|
184
|
+
yield
|
185
|
+
|
186
|
+
if hasattr(self.model_runner.model, "capture_mode"):
|
187
|
+
self.model_runner.model.capture_mode = False
|
188
|
+
|
189
|
+
def can_run(self, forward_batch: ForwardBatch):
|
190
|
+
is_bs_supported = (
|
191
|
+
forward_batch.batch_size in self.graphs
|
192
|
+
if self.disable_padding
|
193
|
+
else forward_batch.batch_size <= self.max_bs
|
194
|
+
)
|
195
|
+
|
196
|
+
# NOTE: cuda graph cannot handle mixed batch (encoder_len = 0)
|
197
|
+
# If mixed batch cannot be supported, then encoder_lens can be removed in cuda graph
|
198
|
+
# because the full_text_row_masked_out_mask tensor will always be ones
|
199
|
+
is_encoder_lens_supported = (
|
200
|
+
torch.all(forward_batch.encoder_lens > 0)
|
201
|
+
if self.is_encoder_decoder
|
202
|
+
else True
|
203
|
+
)
|
204
|
+
return is_bs_supported and is_encoder_lens_supported
|
165
205
|
|
166
206
|
def capture(self):
|
167
207
|
with graph_capture() as graph_capture_context:
|
@@ -188,10 +228,20 @@ class CudaGraphRunner:
|
|
188
228
|
req_pool_indices = self.req_pool_indices[:bs]
|
189
229
|
seq_lens = self.seq_lens[:bs]
|
190
230
|
out_cache_loc = self.out_cache_loc[:bs]
|
231
|
+
if self.is_encoder_decoder:
|
232
|
+
encoder_lens = self.encoder_lens[:bs]
|
233
|
+
else:
|
234
|
+
encoder_lens = None
|
235
|
+
|
236
|
+
seq_lens_sum = seq_lens.sum().item()
|
237
|
+
mrope_positions = self.mrope_positions[:, :bs]
|
191
238
|
|
192
239
|
# Attention backend
|
193
240
|
self.model_runner.attn_backend.init_forward_metadata_capture_cuda_graph(
|
194
|
-
bs,
|
241
|
+
bs,
|
242
|
+
req_pool_indices,
|
243
|
+
seq_lens,
|
244
|
+
encoder_lens,
|
195
245
|
)
|
196
246
|
|
197
247
|
# Run and capture
|
@@ -206,11 +256,15 @@ class CudaGraphRunner:
|
|
206
256
|
token_to_kv_pool=self.model_runner.token_to_kv_pool,
|
207
257
|
attn_backend=self.model_runner.attn_backend,
|
208
258
|
out_cache_loc=out_cache_loc,
|
259
|
+
seq_lens_sum=seq_lens_sum,
|
260
|
+
encoder_lens=encoder_lens,
|
209
261
|
return_logprob=False,
|
210
262
|
top_logprobs_nums=[0] * bs,
|
211
|
-
positions=
|
263
|
+
positions=clamp_position(seq_lens),
|
264
|
+
mrope_positions=mrope_positions,
|
212
265
|
)
|
213
|
-
|
266
|
+
logits_output = forward(input_ids, forward_batch.positions, forward_batch)
|
267
|
+
return logits_output.next_token_logits
|
214
268
|
|
215
269
|
for _ in range(2):
|
216
270
|
torch.cuda.synchronize()
|
@@ -241,7 +295,7 @@ class CudaGraphRunner:
|
|
241
295
|
index = bisect.bisect_left(self.capture_bs, raw_bs)
|
242
296
|
bs = self.capture_bs[index]
|
243
297
|
if bs != raw_bs:
|
244
|
-
self.seq_lens.fill_(
|
298
|
+
self.seq_lens.fill_(1)
|
245
299
|
self.out_cache_loc.zero_()
|
246
300
|
|
247
301
|
# Common inputs
|
@@ -249,31 +303,32 @@ class CudaGraphRunner:
|
|
249
303
|
self.req_pool_indices[:raw_bs].copy_(forward_batch.req_pool_indices)
|
250
304
|
self.seq_lens[:raw_bs].copy_(forward_batch.seq_lens)
|
251
305
|
self.out_cache_loc[:raw_bs].copy_(forward_batch.out_cache_loc)
|
306
|
+
if self.is_encoder_decoder:
|
307
|
+
self.encoder_lens[:raw_bs].copy_(forward_batch.encoder_lens)
|
308
|
+
if forward_batch.mrope_positions is not None:
|
309
|
+
self.mrope_positions[:, :raw_bs].copy_(forward_batch.mrope_positions)
|
252
310
|
|
253
311
|
# Attention backend
|
254
312
|
self.model_runner.attn_backend.init_forward_metadata_replay_cuda_graph(
|
255
|
-
bs,
|
313
|
+
bs,
|
314
|
+
self.req_pool_indices,
|
315
|
+
self.seq_lens,
|
316
|
+
forward_batch.seq_lens_sum + (bs - raw_bs),
|
317
|
+
self.encoder_lens,
|
256
318
|
)
|
257
319
|
|
258
320
|
# Replay
|
259
321
|
self.graphs[bs].replay()
|
260
|
-
|
261
|
-
|
262
|
-
# Unpad
|
263
|
-
if bs != raw_bs:
|
264
|
-
logits_output = LogitsProcessorOutput(
|
265
|
-
next_token_logits=logits_output.next_token_logits[:raw_bs],
|
266
|
-
next_token_logprobs=None,
|
267
|
-
normalized_prompt_logprobs=None,
|
268
|
-
input_token_logprobs=None,
|
269
|
-
input_top_logprobs=None,
|
270
|
-
output_top_logprobs=None,
|
271
|
-
)
|
322
|
+
next_token_logits = self.output_buffers[bs][:raw_bs]
|
272
323
|
|
273
324
|
# Extract logprobs
|
274
325
|
if forward_batch.return_logprob:
|
275
|
-
|
276
|
-
|
326
|
+
next_token_logprobs = torch.nn.functional.log_softmax(
|
327
|
+
next_token_logits, dim=-1
|
328
|
+
)
|
329
|
+
logits_output = LogitsProcessorOutput(
|
330
|
+
next_token_logits=next_token_logits,
|
331
|
+
next_token_logprobs=next_token_logprobs,
|
277
332
|
)
|
278
333
|
return_top_logprob = any(x > 0 for x in forward_batch.top_logprobs_nums)
|
279
334
|
if return_top_logprob:
|
@@ -282,7 +337,11 @@ class CudaGraphRunner:
|
|
282
337
|
top_logprobs_nums=forward_batch.top_logprobs_nums,
|
283
338
|
)
|
284
339
|
logits_output.output_top_logprobs = LogitsProcessor.get_top_logprobs(
|
285
|
-
|
340
|
+
next_token_logprobs, logits_metadata
|
286
341
|
)[1]
|
342
|
+
else:
|
343
|
+
logits_output = LogitsProcessorOutput(
|
344
|
+
next_token_logits=next_token_logits,
|
345
|
+
)
|
287
346
|
|
288
347
|
return logits_output
|