sglang 0.4.1.post6-py3-none-any.whl → 0.4.2-py3-none-any.whl

This diff compares publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -29,8 +29,8 @@ from sglang.srt.utils import (
     get_nvgpu_memory_capacity,
     is_flashinfer_available,
     is_hip,
-    is_ipv6,
     is_port_available,
+    is_valid_ipv6_address,
     nullable_str,
 )
 
@@ -75,6 +75,7 @@ class ServerArgs:
     # Other runtime options
     tp_size: int = 1
     stream_interval: int = 1
+    stream_output: bool = False
     random_seed: Optional[int] = None
     constrained_json_whitespace_pattern: Optional[str] = None
     watchdog_timeout: float = 300
@@ -157,6 +158,11 @@ class ServerArgs:
     num_continuous_decode_steps: int = 1
     delete_ckpt_after_loading: bool = False
     enable_memory_saver: bool = False
+    allow_auto_truncate: bool = False
+
+    # Custom logit processor
+    enable_custom_logit_processor: bool = False
+    tool_call_parser: str = None
 
     def __post_init__(self):
         # Set missing default values
@@ -240,14 +246,13 @@ class ServerArgs:
         # Others
         if self.enable_dp_attention:
             self.dp_size = self.tp_size
+            assert self.tp_size % self.dp_size == 0
             self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
-            self.disable_overlap_schedule = True
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
                 f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
                 "Data parallel size is adjusted to be the same as tensor parallel size. "
-                "Overlap scheduler is disabled."
             )
 
         # Speculative Decoding
@@ -314,6 +319,7 @@ class ServerArgs:
                 "dummy",
                 "gguf",
                 "bitsandbytes",
+                "layered",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -327,7 +333,10 @@ class ServerArgs:
             "which is mainly for profiling."
             '"gguf" will load the weights in the gguf format. '
            '"bitsandbytes" will load the weights using bitsandbytes '
-            "quantization.",
+            "quantization."
+            '"layered" loads weights layer by layer so that one can quantize a '
+            "layer before loading another to make the peak memory envelope "
+            "smaller.",
         )
         parser.add_argument(
             "--trust-remote-code",
@@ -392,7 +401,7 @@ class ServerArgs:
             "--device",
             type=str,
             default="cuda",
-            choices=["cuda", "xpu", "hpu"],
+            choices=["cuda", "xpu", "hpu", "cpu"],
             help="The device type.",
         )
         parser.add_argument(
@@ -492,6 +501,11 @@ class ServerArgs:
             default=ServerArgs.stream_interval,
             help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
         )
+        parser.add_argument(
+            "--stream-output",
+            action="store_true",
+            help="Whether to output as a sequence of disjoint segments.",
+        )
         parser.add_argument(
             "--random-seed",
             type=int,
@@ -860,6 +874,24 @@ class ServerArgs:
             action="store_true",
             help="Allow saving memory using release_memory_occupation and resume_memory_occupation",
         )
+        parser.add_argument(
+            "--allow-auto-truncate",
+            action="store_true",
+            help="Allow automatically truncating requests that exceed the maximum input length instead of returning an error.",
+        )
+        parser.add_argument(
+            "--enable-custom-logit-processor",
+            action="store_true",
+            help="Enable users to pass custom logit processors to the server (disabled by default for security)",
+        )
+        # Function Calling
+        parser.add_argument(
+            "--tool-call-parser",
+            type=str,
+            choices=["qwen25", "mistral", "llama3"],
+            default=ServerArgs.tool_call_parser,
+            help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+        )
 
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
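
The new CLI flags above map one-to-one onto the ServerArgs fields added earlier in this file, and all of them are opt-in. As an illustration only (the model path is a placeholder, not part of this diff), a launch command exercising them might look like:

    python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
        --stream-output --allow-auto-truncate \
        --enable-custom-logit-processor --tool-call-parser llama3
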
@@ -870,7 +902,7 @@ class ServerArgs:
         return cls(**{attr: getattr(args, attr) for attr in attrs})
 
     def url(self):
-        if is_ipv6(self.host):
+        if is_valid_ipv6_address(self.host):
             return f"http://[{self.host}]:{self.port}"
         else:
             return f"http://{self.host}:{self.port}"
@@ -880,8 +912,8 @@ class ServerArgs:
             self.tp_size % self.nnodes == 0
         ), "tp_size must be divisible by number of nodes"
         assert not (
-            self.dp_size > 1 and self.nnodes != 1
-        ), "multi-node data parallel is not supported"
+            self.dp_size > 1 and self.nnodes != 1 and not self.enable_dp_attention
+        ), "multi-node data parallel is not supported unless dp attention!"
         assert (
             self.max_loras_per_batch > 0
             # FIXME
@@ -919,6 +951,9 @@ def prepare_server_args(argv: List[str]) -> ServerArgs:
     return server_args
 
 
+ZMQ_TCP_PORT_DELTA = 233
+
+
 @dataclasses.dataclass
 class PortArgs:
     # The ipc filename for tokenizer to receive inputs from detokenizer (zmq)
@@ -932,7 +967,7 @@ class PortArgs:
     nccl_port: int
 
     @staticmethod
-    def init_new(server_args) -> "PortArgs":
+    def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
         while True:
             if is_port_available(port):
@@ -942,12 +977,39 @@ class PortArgs:
             else:
                 port -= 43
 
-        return PortArgs(
-            tokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            scheduler_input_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            detokenizer_ipc_name=tempfile.NamedTemporaryFile(delete=False).name,
-            nccl_port=port,
-        )
+        if not server_args.enable_dp_attention:
+            # Normal case, use IPC within a single node
+            return PortArgs(
+                tokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                nccl_port=port,
+            )
+        else:
+            # DP attention. Use TCP + port to handle both single-node and multi-node.
+            if server_args.nnodes == 1 and server_args.dist_init_addr is None:
+                dist_init_addr = ("127.0.0.1", server_args.port + ZMQ_TCP_PORT_DELTA)
+            else:
+                dist_init_addr = server_args.dist_init_addr.split(":")
+            assert (
+                len(dist_init_addr) == 2
+            ), "please provide --dist-init-addr as host:port of head node"
+
+            dist_init_host, dist_init_port = dist_init_addr
+            port_base = int(dist_init_port) + 1
+            if dp_rank is None:
+                scheduler_input_port = (
+                    port_base + 2
+                )  # TokenizerManager to DataParallelController
+            else:
+                scheduler_input_port = port_base + 2 + 1 + dp_rank
+
+            return PortArgs(
+                tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
+                scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
+                detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
+                nccl_port=port,
+            )
 
 
 class LoRAPathAction(argparse.Action):
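
For readers tracing the new TCP branch of PortArgs.init_new, the ZMQ endpoint layout is plain arithmetic on the --dist-init-addr port. Below is a minimal standalone sketch of that layout; the host and port values are illustrative and the helper name is invented for this example:

    from typing import Optional

    def dp_attention_endpoints(host: str, dist_init_port: int, dp_rank: Optional[int] = None):
        # Mirrors the arithmetic in PortArgs.init_new above.
        port_base = dist_init_port + 1
        if dp_rank is None:
            scheduler_input_port = port_base + 2  # TokenizerManager -> DataParallelController
        else:
            scheduler_input_port = port_base + 2 + 1 + dp_rank  # one port per DP rank
        return {
            "tokenizer_ipc_name": f"tcp://{host}:{port_base}",
            "detokenizer_ipc_name": f"tcp://{host}:{port_base + 1}",
            "scheduler_input_ipc_name": f"tcp://{host}:{scheduler_input_port}",
        }

    # Head node at 10.0.0.1:5000 -> tokenizer on 5001, detokenizer on 5002,
    # scheduler input on 5003 (controller) or 5004 + dp_rank (per-rank schedulers).
    print(dp_attention_endpoints("10.0.0.1", 5000, dp_rank=0))
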
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
 class EAGLEDraftInput(SpecInfo):
     def __init__(self):
         self.prev_mode = ForwardMode.DECODE
-        self.sample_output = None
 
         self.scores: torch.Tensor = None
         self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list: List[torch.Tenor] = []
         self.iter = 0
 
+        # shape: (b, hidden_size)
         self.hidden_states: torch.Tensor = None
+        # shape: (b,)
         self.verified_id: torch.Tensor = None
+        # shape: (b, vocab_size)
+        self.sample_output: torch.Tensor = None
+
         self.positions: torch.Tensor = None
         self.accept_length: torch.Tensor = None
-        self.has_finished: bool = False
-        self.unfinished_index: List[int] = None
+        self.accept_length_cpu: List[int] = None
 
     def load_server_args(self, server_args: ServerArgs):
         self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
                 :pre_len
             ] = req.prefix_indices
 
-            batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+            batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
                 out_cache_loc[pt : pt + req.extend_input_len]
             )
 
@@ -228,6 +231,14 @@ class EAGLEDraftInput(SpecInfo):
         assert len(batch.extend_lens) == 1
         batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))
 
+    def filter_batch(
+        self,
+        new_indices: torch.Tensor,
+    ):
+        self.sample_output = self.sample_output[: len(new_indices)]
+        self.hidden_states = self.hidden_states[: len(new_indices)]
+        self.verified_id = self.verified_id[: len(new_indices)]
+
     def prepare_for_decode(self, batch: ScheduleBatch):
         prob = self.sample_output  # shape: (b * top_k, vocab) or (b, vocab)
         top = torch.topk(prob, self.topk, dim=-1)
@@ -287,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
         self.cache_list.append(batch.out_cache_loc)
         self.positions = (
             batch.seq_lens[:, None]
-            + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+            + torch.full(
+                [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+            )
         ).flatten()
 
         bs = len(batch.seq_lens)
@@ -304,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):
 
     def prepare_extend_after_decode(self, batch: ScheduleBatch):
         batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
-        batch.extend_lens = (self.accept_length + 1).tolist()
+        accept_length_cpu = batch.spec_info.accept_length_cpu
+        batch.extend_lens = [x + 1 for x in accept_length_cpu]
+        batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+        seq_lens_cpu = batch.seq_lens.tolist()
 
         pt = 0
-        seq_lens = batch.seq_lens.tolist()
-
         i = 0
-
         for req in batch.reqs:
             if req.finished():
                 continue
             # assert seq_len - pre_len == req.extend_input_len
-            input_len = self.accept_length[i] + 1
-            seq_len = seq_lens[i]
+            input_len = batch.extend_lens[i]
+            seq_len = seq_lens_cpu[i]
             batch.req_to_token_pool.req_to_token[req.req_pool_idx][
                 seq_len - input_len : seq_len
             ] = batch.out_cache_loc[pt : pt + input_len]
             pt += input_len
             i += 1
+        assert pt == batch.out_cache_loc.shape[0]
 
         self.positions = torch.empty_like(self.verified_id)
         new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -337,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
             triton.next_power_of_2(self.spec_steps + 1),
         )
 
-        batch.seq_lens_sum = sum(batch.seq_lens)
+        batch.seq_lens_sum = sum(seq_lens_cpu)
         batch.input_ids = self.verified_id
         self.verified_id = new_verified_id
 
@@ -565,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
         finished_extend_len = {}  # {rid:accept_length + 1}
         accept_index_cpu = accept_index.tolist()
         predict_cpu = predict.tolist()
+        has_finished = False
+
         # iterate every accepted token and check if req has finished after append the token
         # should be checked BEFORE free kv cache slots
         for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -578,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
                 finished_extend_len[req.rid] = j + 1
                 req.check_finished()
                 if req.finished():
-                    draft_input.has_finished = True
+                    has_finished = True
                     # set all tokens after finished token to -1 and break
                     accept_index[i, j + 1 :] = -1
                     break
@@ -587,12 +603,12 @@ class EagleVerifyInput(SpecInfo):
             if not req.finished():
                 new_accept_index.extend(new_accept_index_)
                 unfinished_index.append(i)
+            req.spec_verify_ct += 1
         accept_length = (accept_index != -1).sum(dim=1) - 1
 
         accept_index = accept_index[accept_index != -1]
         accept_length_cpu = accept_length.tolist()
         verified_id = predict[accept_index]
-        verified_id_cpu = verified_id.tolist()
 
         evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
         evict_mask[accept_index] = False
@@ -614,7 +630,13 @@ class EagleVerifyInput(SpecInfo):
             draft_input.verified_id = predict[new_accept_index]
             draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
             draft_input.accept_length = accept_length[unfinished_index]
-            draft_input.unfinished_index = unfinished_index
+            draft_input.accept_length_cpu = [
+                accept_length_cpu[i] for i in unfinished_index
+            ]
+            if has_finished:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+            else:
+                draft_input.seq_lens_for_draft_extend = batch.seq_lens
 
         logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
         return (
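
To make the accept-bookkeeping above concrete, here is a small worked example (toy numbers, not from the diff) of how accept_length and the flattened accept_index behave with -1 as the rejection sentinel:

    import torch

    # Two requests, up to 4 verified positions each. -1 marks rejected
    # positions (or positions zeroed out after a request finished, as in
    # the verification loop above).
    accept_index = torch.tensor(
        [
            [10, 11, 12, -1],  # 2 draft tokens accepted beyond the committed one
            [20, -1, -1, -1],  # no draft tokens accepted
        ]
    )

    # Same formula as in EagleVerifyInput: the first entry of each row is the
    # committed token, so subtract 1 to count only the accepted draft tokens.
    accept_length = (accept_index != -1).sum(dim=1) - 1
    assert accept_length.tolist() == [2, 0]

    # Boolean masking flattens the tensor and drops the -1 sentinels, giving
    # the indices used to gather verified token ids (verified_id = predict[...]).
    assert accept_index[accept_index != -1].tolist() == [10, 11, 12, 20]
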
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+from sglang.srt.utils import rank0_print
 
 
 class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):
 
     def forward_draft_decode(self, batch: ScheduleBatch):
         batch.spec_info.prepare_for_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
 
     def forward_draft_extend(self, batch: ScheduleBatch):
         self._set_mem_pool(batch, self.model_runner)
         batch.spec_info.prepare_for_extend(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
         self.capture_for_decode(logits_output, forward_batch)
         self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
         batch.req_to_token_pool = runner.req_to_token_pool
 
     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+        seq_lens_backup = batch.seq_lens
+
         self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
-        if batch.spec_info.has_finished:
-            index = batch.spec_info.unfinished_index
-            seq_lens = batch.seq_lens
-            batch.seq_lens = batch.seq_lens[index]
-
         batch.spec_info.prepare_extend_after_decode(batch)
+        batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
         logits_output = self.model_runner.forward(forward_batch)
-
-        batch.spec_info.hidden_states = logits_output.hidden_states
         self.capture_for_decode(logits_output, forward_batch)
-        batch.forward_mode = ForwardMode.DECODE
-        if batch.spec_info.has_finished:
-            batch.seq_lens = seq_lens
         self._set_mem_pool(batch, self.target_worker.model_runner)
 
+        # Restore backup.
+        # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+        batch.forward_mode = ForwardMode.DECODE
+        batch.seq_lens = seq_lens_backup
+
     def capture_for_decode(
         self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
     ):
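
One subtlety in forward_draft_extend_after_decode above: the backup works because prepare_extend_after_decode rebinds batch.seq_lens to a different tensor rather than writing into it, so a plain reference (no clone) is enough to restore the original. A toy illustration of that distinction, with a stand-in Batch class rather than the sglang one:

    import torch

    class Batch:
        def __init__(self) -> None:
            self.seq_lens = torch.tensor([5, 7, 9])

    batch = Batch()
    seq_lens_backup = batch.seq_lens        # a reference; no .clone() needed
    batch.seq_lens = torch.tensor([6, 8])   # rebind, as in the draft-extend path
    batch.seq_lens = seq_lens_backup        # restore: the original tensor is intact
    assert batch.seq_lens.tolist() == [5, 7, 9]
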