sglang 0.4.3.post3__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_serving.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +94 -48
- sglang/srt/layers/attention/triton_backend.py +4 -2
- sglang/srt/managers/io_struct.py +1 -0
- sglang/srt/managers/scheduler.py +144 -127
- sglang/srt/managers/tokenizer_manager.py +1 -0
- sglang/srt/mem_cache/memory_pool.py +34 -29
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +1 -7
- sglang/srt/model_executor/model_runner.py +97 -78
- sglang/srt/server_args.py +3 -12
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +67 -32
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +2 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +21 -21
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py
CHANGED
```diff
@@ -122,66 +122,17 @@ class ModelRunner:
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    self.server_args.attention_backend = "flashinfer_mla"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
+        self.model_specific_adjustment()

-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
-
-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache

             disable_cache()

+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -203,6 +154,7 @@ class ModelRunner:
             }
         )

+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

         # Get memory before model loading
@@ -216,18 +168,6 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()

-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -244,9 +184,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

-        # Init
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -260,10 +202,63 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)

+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
@@ -400,6 +395,18 @@ class ModelRunner:
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -710,15 +717,6 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker

-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -753,6 +751,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -770,6 +779,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
            assert self.sliding_window_size is None, (
@@ -878,18 +891,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward(self, forward_batch: ForwardBatch) -> LogitsProcessorOutput:
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(forward_batch)
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )

         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(forward_batch)
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
```
sglang/srt/server_args.py
CHANGED
```diff
@@ -71,7 +71,6 @@ class ServerArgs:
     schedule_policy: str = "fcfs"
     schedule_conservativeness: float = 1.0
     cpu_offload_gb: int = 0
-    prefill_only_one_req: bool = False

     # Other runtime options
     tp_size: int = 1
@@ -277,19 +276,17 @@ class ServerArgs:
             self.speculative_algorithm = "EAGLE"

         if self.speculative_algorithm == "EAGLE":
-            self.disable_overlap_schedule = True
-            self.prefill_only_one_req = True
-            self.disable_cuda_graph_padding = True
             if self.max_running_requests is None:
                 self.max_running_requests = 32
+            self.disable_overlap_schedule = True
+            self.disable_cuda_graph_padding = True
             logger.info(
                 "Overlap scheduler are disabled because of using "
                 "eagle speculative decoding."
-                "Max running request set to 32 because of using eagle speculative decoding."
             )
             # The token generated from the verify step is counted.
             # If sepculative_num_steps >= speculative_num_draft_tokens, the additional tokens will definitely be discarded.
-            assert self.speculative_num_steps < self.speculative_num_draft_tokens
+            # assert self.speculative_num_steps < self.speculative_num_draft_tokens

         # GGUF
         if (
@@ -509,12 +506,6 @@ class ServerArgs:
             default=ServerArgs.cpu_offload_gb,
             help="How many GBs of RAM to reserve for CPU offloading",
         )
-        parser.add_argument(
-            "--prefill-only-one-req",
-            type=bool,
-            help="If true, we only prefill one request at one prefill batch",
-            default=ServerArgs.prefill_only_one_req,
-        )

         # Other runtime options
         parser.add_argument(
```
sglang/srt/speculative/build_eagle_tree.py
CHANGED
```diff
@@ -26,7 +26,12 @@ def build_tree_kernel_efficient_preprocess(

     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
-    parent_list = torch.cat(parents_list[:-1], dim=1)
+
+    if len(parents_list) > 1:
+        parent_list = torch.cat(parents_list[:-1], dim=1)
+    else:
+        batch_size = parents_list[0].shape[0]
+        parent_list = torch.empty(batch_size, 0, device=parents_list[0].device)

     return parent_list, top_scores_index, draft_tokens
```
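For context, a minimal standalone sketch of why the new branch is needed (not taken from the package; `pick_parent_list` is a hypothetical helper): with only one tensor in `parents_list`, slicing `parents_list[:-1]` yields an empty list and `torch.cat` would raise, so the patched code substitutes an empty `(batch_size, 0)` tensor instead.

```python
import torch

def pick_parent_list(parents_list):
    # Mirrors the branch added in the diff above.
    if len(parents_list) > 1:
        return torch.cat(parents_list[:-1], dim=1)
    batch_size = parents_list[0].shape[0]
    return torch.empty(batch_size, 0, device=parents_list[0].device)

single_step = [torch.zeros(2, 4, dtype=torch.long)]                    # one drafted step
multi_step = [torch.zeros(2, 4, dtype=torch.long) for _ in range(3)]   # several drafted steps
print(pick_parent_list(single_step).shape)  # torch.Size([2, 0])
print(pick_parent_list(multi_step).shape)   # torch.Size([2, 8])
```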
sglang/srt/speculative/eagle_draft_cuda_graph_runner.py
CHANGED
```diff
@@ -1,7 +1,6 @@
 from __future__ import annotations

 import bisect
-import time
 from typing import TYPE_CHECKING, Callable

 import torch
@@ -162,20 +161,11 @@ class EAGLEDraftCudaGraphRunner:

         run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         with torch.cuda.graph(
             graph, pool=get_global_graph_memory_pool(), stream=stream
         ):
             out = run_once()

-        torch.cuda.synchronize()
-        self.model_runner.tp_group.barrier()
-
         set_global_graph_memory_pool(graph.pool())
         return graph, out

@@ -204,7 +194,7 @@ class EAGLEDraftCudaGraphRunner:

         # Attention backend
         self.model_runner.draft_attn_backend.init_forward_metadata_replay_cuda_graph(
-            forward_batch
+            forward_batch, forward_batch.batch_size
         )

         # Replay
```
sglang/srt/speculative/eagle_utils.py
CHANGED
```diff
@@ -1,7 +1,7 @@
 from __future__ import annotations

 from dataclasses import dataclass
-from typing import TYPE_CHECKING,
+from typing import TYPE_CHECKING, List

 import torch
 import torch.nn.functional as F
@@ -62,6 +62,7 @@ class EagleDraftInput:
             batch.input_ids[pt : pt + extend_len] = torch.concat(
                 (input_ids[1:], self.verified_id[i].reshape(1))
             )
+            pt += extend_len

     def prepare_extend_after_decode(self, batch: ScheduleBatch, speculative_num_steps):
         assert self.verified_id.numel() == batch.out_cache_loc.shape[0]
```
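The `pt += extend_len` line added above fixes a write cursor that never advanced. A minimal sketch of the packing pattern involved, with hypothetical names (`pack`, `flat`, `chunks`):

```python
import torch

def pack(flat: torch.Tensor, chunks) -> None:
    # Without advancing pt, every chunk would overwrite the start of the buffer.
    pt = 0
    for chunk in chunks:
        extend_len = chunk.numel()
        flat[pt : pt + extend_len] = chunk
        pt += extend_len

buf = torch.zeros(6, dtype=torch.long)
pack(buf, [torch.tensor([1, 2]), torch.tensor([3, 4, 5]), torch.tensor([6])])
print(buf.tolist())  # [1, 2, 3, 4, 5, 6]
```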
sglang/srt/speculative/eagle_worker.py
CHANGED
```diff
@@ -1,20 +1,19 @@
 import logging
 import os
 import time
-from typing import
+from typing import List, Optional, Tuple

 import torch
 from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
-from sglang.srt.managers.schedule_batch import
+from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.managers.tp_worker import TpModelWorker
 from sglang.srt.model_executor.forward_batch_info import (
     CaptureHiddenMode,
     ForwardBatch,
     ForwardMode,
 )
-from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
     EAGLEDraftCudaGraphRunner,
@@ -27,7 +26,6 @@ from sglang.srt.speculative.eagle_utils import (
     fast_topk,
     select_top_k_tokens,
 )
-from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)
@@ -44,16 +42,30 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Parse arguments
+        self.server_args = server_args
+        self.topk = server_args.speculative_eagle_topk
+        self.speculative_num_steps = server_args.speculative_num_steps
+        self.padded_static_len = self.speculative_num_steps + 1
+        self.enable_nan_detection = server_args.enable_nan_detection
+        self.gpu_id = gpu_id
+        self.device = server_args.device
+        self.target_worker = target_worker
+
         # Override context length with target model's context length
         server_args.context_length = target_worker.model_runner.model_config.context_len
-        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"

         # Do not capture cuda graph in `super().__init__()`
-        #
+        # It will be captured later.
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+        # Share the allocator with a target worker.
+        # Draft and target worker own their own KV cache pools.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )

-        #
+        # Load hot token ids
         if server_args.speculative_token_map is not None:
             self.hot_token_id = load_token_map(server_args.speculative_token_map)
             server_args.json_model_override_args = (
@@ -62,13 +74,7 @@ class EAGLEWorker(TpModelWorker):
         else:
             self.hot_token_id = None

-        #
-        # owns its own KV cache.
-        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
-            target_worker.get_memory_pool()
-        )
-
-        # Init target worker
+        # Init draft worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -79,18 +85,6 @@ class EAGLEWorker(TpModelWorker):
             req_to_token_pool=self.req_to_token_pool,
             token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
-        self.target_worker = target_worker
-
-        # Parse arguments
-        self.topk = server_args.speculative_eagle_topk
-        self.speculative_num_steps = server_args.speculative_num_steps
-        self.speculative_algorithm = SpeculativeAlgorithm.from_string(
-            server_args.speculative_algorithm
-        )
-        self.server_args = server_args
-        self.use_nan_detection = self.server_args.enable_nan_detection
-        self.device = self.model_runner.device
-        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
         embed, head = self.target_worker.model_runner.model.get_embed_and_head()
@@ -103,8 +97,12 @@ class EAGLEWorker(TpModelWorker):
             backup_disable_cuda_graph
         )

+        self.init_attention_backend()
+        self.init_cuda_graphs()
+
+    def init_attention_backend(self):
         # Create multi-step attn backends and cuda graph runners
-        if server_args.attention_backend == "flashinfer":
+        if self.server_args.attention_backend == "flashinfer":
             from sglang.srt.layers.attention.flashinfer_backend import (
                 FlashInferMultiStepDraftBackend,
             )
@@ -114,7 +112,7 @@ class EAGLEWorker(TpModelWorker):
                 self.topk,
                 self.speculative_num_steps,
             )
-        elif server_args.attention_backend == "triton":
+        elif self.server_args.attention_backend == "triton":
             from sglang.srt.layers.attention.triton_backend import (
                 TritonMultiStepDraftBackend,
             )
@@ -126,11 +124,9 @@ class EAGLEWorker(TpModelWorker):
             )
         else:
             raise ValueError(
-                f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
+                f"EAGLE is not supportted in attention backend {self.server_args.attention_backend}"
             )
-
         self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
-        self.init_cuda_graphs()

     def init_cuda_graphs(self):
         """Capture cuda graphs."""
@@ -356,6 +352,41 @@ class EAGLEWorker(TpModelWorker):
         batch.forward_mode = ForwardMode.DECODE
         batch.spec_info = res.draft_input

+        if batch.return_logprob:
+            # Compute output logprobs using the sampler.
+            num_tokens_per_req = [
+                accept + 1 for accept in res.accept_length_per_req_cpu
+            ]
+            self.target_worker.model_runner.update_output_logprobs(
+                logits_output,
+                batch.sampling_info,
+                batch.top_logprobs_nums,
+                batch.token_ids_logprobs,
+                res.verified_id,
+                # +1 for bonus token.
+                num_tokens_per_req=num_tokens_per_req,
+            )
+
+            # Add output logprobs to the request.
+            pt = 0
+            # NOTE: tolist() of these values are skipped when output is processed
+            next_token_logprobs = res.logits_output.next_token_logprobs.tolist()
+            verified_ids = res.verified_id.tolist()
+            for req, num_tokens in zip(batch.reqs, num_tokens_per_req):
+                for _ in range(num_tokens):
+                    if req.return_logprob:
+                        token_id = verified_ids[pt]
+                        req.output_token_logprobs_val.append(next_token_logprobs[pt])
+                        req.output_token_logprobs_idx.append(token_id)
+                        if req.top_logprobs_num > 0:
+                            req.output_top_logprobs_val.append(
+                                res.logits_output.next_token_top_logprobs_val[pt]
+                            )
+                            req.output_top_logprobs_idx.append(
+                                res.logits_output.next_token_top_logprobs_idx[pt]
+                            )
+                    pt += 1
+
         return logits_output, res, model_worker_batch

     def forward_draft_extend(
@@ -381,6 +412,7 @@ class EAGLEWorker(TpModelWorker):
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
         )
+        forward_batch.return_logprob = False
         logits_output = self.draft_model_runner.forward(forward_batch)
         self._detect_nan_if_needed(logits_output)
         assert isinstance(forward_batch.spec_info, EagleDraftInput)
@@ -393,6 +425,8 @@ class EAGLEWorker(TpModelWorker):
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         # We don't need logprob for this extend.
+        original_return_logprob = batch.return_logprob
+        batch.return_logprob = False
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
@@ -404,6 +438,7 @@ class EAGLEWorker(TpModelWorker):

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.return_logprob = original_return_logprob
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup

@@ -415,7 +450,7 @@ class EAGLEWorker(TpModelWorker):
         draft_input.hidden_states = logits_output.hidden_states

     def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
-        if self.use_nan_detection:
+        if self.enable_nan_detection:
             logits = logits_output.next_token_logits
             if torch.any(torch.isnan(logits)):
                 logger.warning("Detected errors during sampling! NaN in the logits.")
```
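The draft-extend changes above follow a save/disable/restore pattern for `batch.return_logprob`. A minimal sketch of that pattern with a hypothetical `Batch` class (the patch restores the flag explicitly later rather than via `finally`):

```python
from dataclasses import dataclass

@dataclass
class Batch:
    return_logprob: bool = True

def run_draft_extend(batch: Batch) -> None:
    original_return_logprob = batch.return_logprob
    batch.return_logprob = False  # the draft pass does not need logprobs
    try:
        pass  # draft model forward would run here
    finally:
        batch.return_logprob = original_return_logprob  # restore caller's setting

b = Batch(return_logprob=True)
run_draft_extend(b)
assert b.return_logprob is True
```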
sglang/version.py
CHANGED
```diff
@@ -1 +1 @@
-__version__ = "0.4.3.post3"
+__version__ = "0.4.3.post4"
```
{sglang-0.4.3.post3.dist-info → sglang-0.4.3.post4.dist-info}/METADATA
CHANGED
```diff
@@ -1,6 +1,6 @@
 Metadata-Version: 2.2
 Name: sglang
-Version: 0.4.3.post3
+Version: 0.4.3.post4
 Summary: SGLang is yet another fast serving framework for large language models and vision language models.
 License: Apache License
 Version 2.0, January 2004
@@ -239,6 +239,7 @@ Requires-Dist: xgrammar==0.1.14; extra == "runtime-common"
 Requires-Dist: ninja; extra == "runtime-common"
 Requires-Dist: transformers==4.48.3; extra == "runtime-common"
 Requires-Dist: llguidance>=0.6.15; extra == "runtime-common"
+Requires-Dist: datasets; extra == "runtime-common"
 Provides-Extra: srt
 Requires-Dist: sglang[runtime_common]; extra == "srt"
 Requires-Dist: sgl-kernel==0.0.3.post6; extra == "srt"
```