sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -29,8 +29,8 @@ from sglang.srt.utils import (
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
-    is_ipv6,
     is_port_available,
+    is_valid_ipv6_address,
     nullable_str,
 )

@@ -75,6 +75,7 @@ class ServerArgs:
     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 1
+    stream_output: bool = False
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
@@ -157,6 +158,11 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False
+
+    # Custom logit processor
+    enable_custom_logit_processor: bool = False
+    tool_call_parser: str = None

     def __post_init__(self):
         # Set missing default values
@@ -240,14 +246,13 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap scheduler is disabled."
             )

         # Speculative Decoding
@@ -314,6 +319,7 @@ class ServerArgs:
                 "dummy",
                 "gguf",
                 "bitsandbytes",
+                "layered",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -327,7 +333,10 @@ class ServerArgs:
             "which is mainly for profiling."
             '"gguf" will load the weights in the gguf format. '
             '"bitsandbytes" will load the weights using bitsandbytes '
-            "quantization.",
+            "quantization."
+            '"layered" loads weights layer by layer so that one can quantize a '
+            "layer before loading another to make the peak memory envelope "
+            "smaller.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -392,7 +401,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda", "xpu", "hpu"],
+            choices=["cuda", "xpu", "hpu", "cpu"],
             help="The device type.",
         )
         parser.add_argument(
@@ -492,6 +501,11 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--stream-output",
+            action="store_true",
+            help="Whether to output as a sequence of disjoint segments.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
@@ -860,6 +874,24 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )
+        # Function Calling
+        parser.add_argument(
+            "--tool-call-parser",
+            type=str,
+            choices=["qwen25", "mistral", "llama3"],
+            default=ServerArgs.tool_call_parser,
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+        )

     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
@@ -870,7 +902,7 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})

     def url(self):
-        if is_ipv6(self.host):
+        if is_valid_ipv6_address(self.host):
             return f"http://[{self.host}]:{self.port}"
         else:
             return f"http://{self.host}:{self.port}"
@@ -880,8 +912,8 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention!"
         assert (
             self.max_loras_per_batch > 0
             # FIXME
@@ -919,6 +951,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     return server_args


+ZMQ_TCP_PORT_DELTA = 233
+
+
 @dataclasses.dataclass
 class PortArgs:
     # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -932,7 +967,7 @@ class PortArgs:
     nccl_port: int

     @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
@@ -942,12 +977,39 @@ class PortArgs:
             else:
                 port -= 43

-        return PortArgs(
-            tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
-            scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
-            detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+                assert (
+                    len(dist_init_addr) == 2
+                ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )


 class LoRAPathAction(argparse.Action):
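
To make the DP-attention endpoint layout above concrete: with --enable-dp-attention, the ZMQ sockets move from per-process IPC files to TCP ports derived from the head node's --dist-init-addr (or, on a single node, from the HTTP port plus ZMQ_TCP_PORT_DELTA). Below is a minimal, runnable sketch of that derivation; derive_ports is a hypothetical helper written only for illustration, while the real PortArgs.init_new returns a PortArgs dataclass and keeps IPC sockets when DP attention is off.

from typing import Optional

ZMQ_TCP_PORT_DELTA = 233  # same constant the diff adds to server_args.py


def derive_ports(
    server_port: int,
    dist_init_addr: Optional[str] = None,
    dp_rank: Optional[int] = None,
) -> dict:
    # Resolve the head-node host and base port, mirroring PortArgs.init_new.
    if dist_init_addr is None:
        host, dist_init_port = "127.0.0.1", server_port + ZMQ_TCP_PORT_DELTA
    else:
        host, port_str = dist_init_addr.split(":")
        dist_init_port = int(port_str)

    port_base = dist_init_port + 1
    if dp_rank is None:
        # TokenizerManager feeds the DataParallelController directly.
        scheduler_input_port = port_base + 2
    else:
        # Each DP rank gets its own scheduler-input port.
        scheduler_input_port = port_base + 2 + 1 + dp_rank

    return {
        "tokenizer": f"tcp://{host}:{port_base}",
        "detokenizer": f"tcp://{host}:{port_base + 1}",
        "scheduler_input": f"tcp://{host}:{scheduler_input_port}",
    }

For example, derive_ports(30000, dp_rank=1) puts the tokenizer on tcp://127.0.0.1:30234, the detokenizer on tcp://127.0.0.1:30235, and the rank-1 scheduler input on tcp://127.0.0.1:30238.
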
sglang/srt/speculative/eagle_utils.py
CHANGED
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
     def __init__(self):
         self.prev_mode = ForwardMode.DECODE
-        self.sample_output = None

         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0

+        # shape: (b, hidden_size)
         self.hidden_states: torch.Tensor = None
+        # shape: (b,)
         self.verified_id: torch.Tensor = None
+        # shape: (b, vocab_size)
+        self.sample_output: torch.Tensor = None
+
         self.positions: torch.Tensor = None
         self.accept_length: torch.Tensor = None
-        self.has_finished = False
-        self.unfinished_index: List[int] = None
+        self.accept_length_cpu: List[int] = None

     def load_server_args(self, server_args: ServerArgs):
         self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
                 :pre_len
             ] = req.prefix_indices

-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )

@@ -228,6 +231,14 @@ class EAGLEDraftInput(SpecInfo):
         assert len(batch.extend_lens) == 1
         batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

+    def filter_batch(
+        self,
+        new_indices: torch.Tensor,
+    ):
+        self.sample_output = self.sample_output[: len(new_indices)]
+        self.hidden_states = self.hidden_states[: len(new_indices)]
+        self.verified_id = self.verified_id[: len(new_indices)]
+
     def prepare_for_decode(self, batch: ScheduleBatch):
         prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
@@ -287,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
-            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+            + torch.full(
+                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+            )
         ).flatten()

         bs = len(batch.seq_lens)
@@ -304,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):

     def prepare_extend_after_decode(self, batch: ScheduleBatch):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
-
+        accept_length_cpu = batch.spec_info.accept_length_cpu
+        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        seq_lens_cpu = batch.seq_lens.tolist()

         pt = 0
-        seq_lens = batch.seq_lens.tolist()
-
         i = 0
-
         for req in batch.reqs:
             if req.finished():
                 continue
             # assert seq_len - pre_len == req.extend_input_len
-            input_len = req.extend_input_len
-            seq_len = seq_lens[i]
+            input_len = batch.extend_lens[i]
+            seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
+        assert pt == batch.out_cache_loc.shape[0]

         self.positions = torch.empty_like(self.verified_id)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -337,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )

-        batch.seq_lens_sum = sum(seq_lens)
+        batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id

@@ -565,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
         finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
+        has_finished = False
+
         # iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -578,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
                     finished_extend_len[req.rid] = j + 1
                     req.check_finished()
                     if req.finished():
-                        draft_input.has_finished = True
+                        has_finished = True
                         # set all tokens after finished token to -1 and break
                         accept_index[i, j + 1 :] = -1
                         break
@@ -587,12 +603,12 @@ class EagleVerifyInput(SpecInfo):
             if not req.finished():
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
+                req.spec_verify_ct += 1
         accept_length = (accept_index != -1).sum(dim=1) - 1

         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-        verified_id_cpu = verified_id.tolist()

         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
@@ -614,7 +630,13 @@ class EagleVerifyInput(SpecInfo):
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.unfinished_index = unfinished_index
+            draft_input.accept_length_cpu = [
+                accept_length_cpu[i] for i in unfinished_index
+            ]
+            if has_finished:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            else:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens

         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
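
Background for the verify bookkeeping above: accept_index holds, per request, the flattened positions of accepted draft tokens, with -1 marking rejected slots. A toy, runnable illustration follows; the tensor values are invented for the example, and only the operations (mask, row sum, flatten-select) come from the diff.

import torch

# One row per request; -1 marks draft positions rejected during verification.
accept_index = torch.tensor(
    [
        [0, 1, 2, -1],  # request 0 keeps 3 tokens -> accept_length 2
        [4, -1, -1, -1],  # request 1 keeps only the first token -> accept_length 0
    ]
)

# As in the diff: accepted tokens per request, excluding the first verified token.
accept_length = (accept_index != -1).sum(dim=1) - 1
print(accept_length.tolist())  # [2, 0]

# Flatten away rejected slots; verify() uses this to index into `predict`.
flat_index = accept_index[accept_index != -1]
print(flat_index.tolist())  # [0, 1, 2, 4]
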
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print


 class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):

     def forward_draft_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_for_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)

     def forward_draft_extend(self, batch: ScheduleBatch):
         self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
         self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
         batch.req_to_token_pool = runner.req_to_token_pool

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        seq_lens_backup = batch.seq_lens
+
         self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        if batch.spec_info.has_finished:
-            index = batch.spec_info.unfinished_index
-            seq_lens = batch.seq_lens
-            batch.seq_lens = batch.seq_lens[index]
-
         batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
-
-        batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
-        batch.forward_mode = ForwardMode.DECODE
-        if batch.spec_info.has_finished:
-            batch.seq_lens = seq_lens
         self._set_mem_pool(batch, self.target_worker.model_runner)

+        # Restore backup.
+        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.forward_mode = ForwardMode.DECODE
+        batch.seq_lens = seq_lens_backup
+
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):
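
The forward_draft_extend_after_decode rewrite above replaces the conditional save/restore of batch.seq_lens with an unconditional backup, because prepare_extend_after_decode now always overwrites seq_lens from seq_lens_for_draft_extend. A minimal sketch of that pattern with stand-in types; Batch and prepare here are illustrative only, not the real ScheduleBatch API.

from dataclasses import dataclass

import torch


@dataclass
class Batch:
    seq_lens: torch.Tensor


def prepare(batch: Batch) -> None:
    # Stand-in for prepare_extend_after_decode, which may shrink seq_lens
    # to cover only the unfinished requests.
    batch.seq_lens = batch.seq_lens[:-1]


def draft_extend(batch: Batch) -> None:
    seq_lens_backup = batch.seq_lens  # keep a reference before any mutation
    prepare(batch)
    # ... the draft-model forward pass would run here with the modified seq_lens ...
    batch.seq_lens = seq_lens_backup  # restore unconditionally afterwards


batch = Batch(seq_lens=torch.tensor([5, 7, 9]))
draft_extend(batch)
assert batch.seq_lens.tolist() == [5, 7, 9]
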