sglang 0.4.1.post7__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff compares the contents of two publicly released versions of this package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages exactly as they appear in their public registry.
Files changed (67)
  1. sglang/bench_offline_throughput.py +17 -11
  2. sglang/bench_one_batch.py +14 -6
  3. sglang/bench_serving.py +47 -44
  4. sglang/lang/chat_template.py +31 -0
  5. sglang/srt/configs/load_config.py +1 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +5 -2
  7. sglang/srt/entrypoints/engine.py +5 -2
  8. sglang/srt/entrypoints/http_server.py +24 -0
  9. sglang/srt/function_call_parser.py +494 -0
  10. sglang/srt/layers/activation.py +5 -5
  11. sglang/srt/layers/dp_attention.py +3 -1
  12. sglang/srt/layers/layernorm.py +5 -5
  13. sglang/srt/layers/linear.py +24 -9
  14. sglang/srt/layers/logits_processor.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/layer.py +20 -12
  16. sglang/srt/layers/moe/fused_moe_native.py +17 -3
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  18. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -1
  19. sglang/srt/layers/moe/fused_moe_triton/layer.py +9 -0
  20. sglang/srt/layers/parameter.py +16 -7
  21. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  22. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  23. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  24. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  25. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  26. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  27. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  28. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  29. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  30. sglang/srt/layers/quantization/fp8.py +4 -1
  31. sglang/srt/layers/rotary_embedding.py +6 -1
  32. sglang/srt/layers/sampler.py +28 -8
  33. sglang/srt/layers/torchao_utils.py +12 -6
  34. sglang/srt/managers/detokenizer_manager.py +1 -0
  35. sglang/srt/managers/io_struct.py +36 -5
  36. sglang/srt/managers/schedule_batch.py +31 -25
  37. sglang/srt/managers/scheduler.py +61 -35
  38. sglang/srt/managers/tokenizer_manager.py +4 -0
  39. sglang/srt/model_executor/cuda_graph_runner.py +23 -25
  40. sglang/srt/model_executor/forward_batch_info.py +5 -7
  41. sglang/srt/model_executor/model_runner.py +7 -4
  42. sglang/srt/model_loader/loader.py +75 -0
  43. sglang/srt/model_loader/weight_utils.py +91 -5
  44. sglang/srt/models/commandr.py +14 -2
  45. sglang/srt/models/dbrx.py +9 -1
  46. sglang/srt/models/deepseek_v2.py +3 -3
  47. sglang/srt/models/gemma2.py +9 -1
  48. sglang/srt/models/grok.py +1 -0
  49. sglang/srt/models/minicpm3.py +3 -3
  50. sglang/srt/models/torch_native_llama.py +17 -4
  51. sglang/srt/openai_api/adapter.py +139 -37
  52. sglang/srt/openai_api/protocol.py +5 -4
  53. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  54. sglang/srt/sampling/sampling_batch_info.py +4 -14
  55. sglang/srt/server.py +2 -2
  56. sglang/srt/server_args.py +20 -1
  57. sglang/srt/speculative/eagle_utils.py +37 -15
  58. sglang/srt/speculative/eagle_worker.py +11 -13
  59. sglang/srt/utils.py +62 -65
  60. sglang/test/test_programs.py +1 -0
  61. sglang/test/test_utils.py +81 -22
  62. sglang/version.py +1 -1
  63. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/METADATA +7 -7
  64. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/RECORD +67 -56
  65. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  66. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  67. {sglang-0.4.1.post7.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/openai_api/adapter.py CHANGED
@@ -20,7 +20,7 @@ import os
  import time
  import uuid
  from http import HTTPStatus
- from typing import Dict, List
+ from typing import Dict, List, Optional

  from fastapi import HTTPException, Request, UploadFile
  from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
  generate_chat_conv,
  register_conv_template,
  )
+ from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
  from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.openai_api.protocol import (
  BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
  TopLogprob,
  UsageInfo,
  )
- from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
  from sglang.utils import get_exception_traceback

  logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
  ret,
  to_file=True,
  cache_report=tokenizer_manager.server_args.enable_cache_report,
+ tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
  )
  else:
  responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
  tools = None
  if request.tools and request.tool_choice != "none":
  request.skip_special_tokens = False
- if request.stream:
- logger.warning("Streaming is not supported with tools.")
- request.stream = False
  if not isinstance(request.tool_choice, str):
  tools = [
  item.function.model_dump()
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
  openai_compatible_messages = openai_compatible_messages[:-1]
  else:
  assistant_prefix = None
- prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
- openai_compatible_messages,
- tokenize=True,
- add_generation_prompt=True,
- tools=tools,
- )
+
+ try:
+ prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+ openai_compatible_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ tools=tools,
+ )
+ except:
+ # This except branch will be triggered when the chosen model
+ # has a different tools input format that is not compatiable
+ # with openAI's apply_chat_template tool_call format, like Mistral.
+ tools = [t if "function" in t else {"function": t} for t in tools]
+ prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+ openai_compatible_messages,
+ tokenize=True,
+ add_generation_prompt=True,
+ tools=tools,
+ )
+
  if assistant_prefix:
  prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
  stop = request.stop
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
  return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


- def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
+ def v1_chat_generate_response(
+ request, ret, to_file=False, cache_report=False, tool_call_parser=None
+ ):
  choices = []

  for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
  if finish_reason == "stop":
  finish_reason = "tool_calls"
  try:
- text, call_info_list = parse_tool_response(text, tools) # noqa
+ parser = FunctionCallParser(tools, tool_call_parser)
+ full_normal_text, call_info_list = parser.parse_non_stream(text)
  tool_calls = [
  ToolCall(
- id=str(call_info[0]),
+ id=str(call_info.tool_index),
  function=FunctionResponse(
- name=call_info[1], arguments=call_info[2]
+ name=call_info.name, arguments=call_info.parameters
  ),
  )
  for call_info in call_info_list
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

  if adapted_request.stream:
+ parser_dict = {}

  async def generate_stream_resp():
  is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  adapted_request, raw_request
  ):
  index = content.get("index", 0)
+ text = content["text"]

  is_first = is_firsts.get(index, True)
  stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

  text = content["text"]
  delta = text[len(stream_buffer) :]
- stream_buffer = stream_buffer + delta
- choice_data = ChatCompletionResponseStreamChoice(
- index=index,
- delta=DeltaMessage(content=delta),
- finish_reason=(finish_reason["type"] if finish_reason else ""),
- matched_stop=(
- finish_reason["matched"]
- if finish_reason and "matched" in finish_reason
- else None
- ),
- logprobs=choice_logprobs,
- )
- chunk = ChatCompletionStreamResponse(
- id=content["meta_info"]["id"],
- choices=[choice_data],
- model=request.model,
- )
+ new_stream_buffer = stream_buffer + delta

- is_firsts[index] = is_first
- stream_buffers[index] = stream_buffer
- n_prev_tokens[index] = n_prev_token
+ if request.tool_choice != "none" and request.tools:
+ if index not in parser_dict:
+ parser_dict[index] = FunctionCallParser(
+ tools=request.tools,
+ tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+ )
+ parser = parser_dict[index]
+
+ # parse_increment => returns (normal_text, calls)
+ normal_text, calls = parser.parse_stream_chunk(delta)
+
+ # 1) if there's normal_text, output it as normal content
+ if normal_text:
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(content=normal_text),
+ finish_reason=(
+ finish_reason["type"] if finish_reason else ""
+ ),
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+
+ # 2) if we found calls, we output them as separate chunk(s)
+ for call_item in calls:
+ # transform call_item -> FunctionResponse + ToolCall
+
+ if (
+ content["meta_info"]["finish_reason"]
+ and content["meta_info"]["finish_reason"]["type"]
+ == "stop"
+ ):
+ latest_delta_len = 0
+ if isinstance(call_item.parameters, str):
+ latest_delta_len = len(call_item.parameters)
+
+ expected_call = json.dumps(
+ parser.multi_format_parser.detectors[0]
+ .prev_tool_call_arr[index]
+ .get("arguments", {}),
+ ensure_ascii=False,
+ )
+ actual_call = parser.multi_format_parser.detectors[
+ 0
+ ].streamed_args_for_tool[index]
+ if latest_delta_len > 0:
+ actual_call = actual_call[:-latest_delta_len]
+ remaining_call = expected_call.replace(
+ actual_call, "", 1
+ )
+ call_item.parameters = remaining_call
+
+ tool_call = ToolCall(
+ id=str(call_item.tool_index),
+ function=FunctionResponse(
+ name=call_item.name,
+ arguments=call_item.parameters,
+ ),
+ )
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(
+ role="assistant", tool_calls=[tool_call]
+ ),
+ finish_reason="tool_call",
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"

- yield f"data: {chunk.model_dump_json()}\n\n"
+ stream_buffers[index] = new_stream_buffer
+ is_firsts[index] = is_first
+
+ else:
+ # No tool calls => just treat this as normal text
+ choice_data = ChatCompletionResponseStreamChoice(
+ index=index,
+ delta=DeltaMessage(content=delta),
+ finish_reason=(
+ finish_reason["type"] if finish_reason else ""
+ ),
+ matched_stop=(
+ finish_reason["matched"]
+ if finish_reason and "matched" in finish_reason
+ else None
+ ),
+ logprobs=choice_logprobs,
+ )
+ chunk = ChatCompletionStreamResponse(
+ id=content["meta_info"]["id"],
+ choices=[choice_data],
+ model=request.model,
+ )
+ yield f"data: {chunk.model_dump_json()}\n\n"
+ stream_buffers[index] = new_stream_buffer
+ is_firsts[index] = is_first
  if request.stream_options and request.stream_options.include_usage:
  total_prompt_tokens = sum(
  tokens
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
  ret = [ret]

  response = v1_chat_generate_response(
- request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+ request,
+ ret,
+ cache_report=tokenizer_manager.server_args.enable_cache_report,
+ tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
  )

  return response
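
Note: with the adapter changes above, /v1/chat/completions no longer forces stream=False when tools are supplied; tool calls are emitted incrementally as DeltaMessage.tool_calls chunks. The following is a minimal client-side sketch of consuming that stream. It is illustrative only and not part of the package: it assumes an sglang server is already running at http://localhost:30000 with a tool-call parser configured, uses the standard openai Python client, and the get_weather tool is a made-up example.

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool, for illustration only
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}]

stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    stream=True,
)

name, args = None, ""
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.content:                         # normal text arrives as content chunks
        print(delta.content, end="")
    for call in delta.tool_calls or []:       # tool calls arrive as tool_calls chunks
        name = call.function.name or name     # the name may arrive only once
        args += call.function.arguments or ""  # arguments arrive incrementally

print("\ntool call:", name, args)
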
sglang/srt/openai_api/protocol.py CHANGED
@@ -262,7 +262,7 @@ class Function(BaseModel):
  """Function descriptions."""

  description: Optional[str] = Field(default=None, examples=[None])
- name: str
+ name: Optional[str] = None
  parameters: Optional[object] = None


@@ -276,7 +276,7 @@ class Tool(BaseModel):
  class ToolChoiceFuncName(BaseModel):
  """The name of tool choice function."""

- name: str
+ name: Optional[str] = None


  class ToolChoice(BaseModel):
@@ -329,8 +329,8 @@ class ChatCompletionRequest(BaseModel):
  class FunctionResponse(BaseModel):
  """Function response."""

- name: str
- arguments: str
+ name: Optional[str] = None
+ arguments: Optional[str] = None


  class ToolCall(BaseModel):
@@ -367,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
  class DeltaMessage(BaseModel):
  role: Optional[str] = None
  content: Optional[str] = None
+ tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


  class ChatCompletionResponseStreamChoice(BaseModel):
sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py CHANGED
@@ -3,11 +3,16 @@ from typing import List
  import torch

  from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
- from sglang.srt.utils import is_cuda_available
+ from sglang.srt.utils import get_compiler_backend

- is_cuda = is_cuda_available()
- if is_cuda:
- from sgl_kernel import sampling_scaling_penalties
+
+ @torch.compile(dynamic=True, backend=get_compiler_backend())
+ def apply_scaling_penalties(logits, scaling_penalties):
+ logits[:] = torch.where(
+ logits > 0,
+ logits / scaling_penalties,
+ logits * scaling_penalties,
+ )


  class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
  self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

  def _apply(self, logits: torch.Tensor) -> torch.Tensor:
- if is_cuda:
- return sampling_scaling_penalties(
- logits, self.cumulated_repetition_penalties
- )
- else:
- return torch.where(
- logits > 0,
- logits / self.cumulated_repetition_penalties,
- logits * self.cumulated_repetition_penalties,
- )
+ apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
+ return logits

  def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
  self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
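
Note: the new apply_scaling_penalties above replaces the optional sgl_kernel fast path with a single torch.compile'd element-wise update. A self-contained sketch of the same arithmetic, using plain PyTorch and no sglang imports (illustrative only):

import torch

def scale_by_repetition_penalty(logits: torch.Tensor, penalties: torch.Tensor) -> torch.Tensor:
    # Same rule as apply_scaling_penalties: positive logits are divided by the penalty,
    # negative logits are multiplied by it, so previously seen tokens always lose probability mass.
    return torch.where(logits > 0, logits / penalties, logits * penalties)

logits = torch.tensor([[2.0, -1.0, 0.5]])
penalties = torch.tensor([[1.2, 1.2, 1.0]])  # > 1.0 penalizes tokens that already appeared
print(scale_by_repetition_penalty(logits, penalties))
# tensor([[ 1.6667, -1.2000,  0.5000]])
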
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -7,14 +7,11 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

  import torch

- from sglang.srt.utils import is_cuda_available
-
- is_cuda = is_cuda_available()
- if is_cuda:
- from sgl_kernel import sampling_scaling_penalties
-
  import sglang.srt.sampling.penaltylib as penaltylib
  from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+ from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+ apply_scaling_penalties,
+ )

  logger = logging.getLogger(__name__)

@@ -386,14 +383,7 @@ class SamplingBatchInfo:

  # repetition
  if self.scaling_penalties is not None:
- if is_cuda:
- logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
- else:
- logits[:] = torch.where(
- logits > 0,
- logits / self.scaling_penalties,
- logits * self.scaling_penalties,
- )
+ apply_scaling_penalties(logits, self.scaling_penalties)

  # Apply regex vocab_mask
  if self.vocab_mask is not None:
sglang/srt/server.py CHANGED
@@ -12,7 +12,7 @@
  # limitations under the License.
  # ==============================================================================

- # Some shortcuts for backward compatbility.
+ # Some shortcuts for backward compatibility.
  # They will be removed in new versions.
  from sglang.srt.entrypoints.engine import Engine
- from sglang.srt.entrypoints.http_server import launch_server
+ from sglang.srt.entrypoints.http_server import kill_process_tree, launch_server
sglang/srt/server_args.py CHANGED
@@ -75,6 +75,7 @@ class ServerArgs:
  # Other runtime options
  tp_size: int = 1
  stream_interval: int = 1
+ stream_output: bool = False
  random_seed: Optional[int] = None
  constrained_json_whitespace_pattern: Optional[str] = None
  watchdog_timeout: float = 300
@@ -161,6 +162,7 @@ class ServerArgs:

  # Custom logit processor
  enable_custom_logit_processor: bool = False
+ tool_call_parser: str = None

  def __post_init__(self):
  # Set missing default values
@@ -317,6 +319,7 @@ class ServerArgs:
  "dummy",
  "gguf",
  "bitsandbytes",
+ "layered",
  ],
  help="The format of the model weights to load. "
  '"auto" will try to load the weights in the safetensors format '
@@ -330,7 +333,10 @@ class ServerArgs:
  "which is mainly for profiling."
  '"gguf" will load the weights in the gguf format. '
  '"bitsandbytes" will load the weights using bitsandbytes '
- "quantization.",
+ "quantization."
+ '"layered" loads weights layer by layer so that one can quantize a '
+ "layer before loading another to make the peak memory envelope "
+ "smaller.",
  )
  parser.add_argument(
  "--trust-remote-code",
@@ -495,6 +501,11 @@ class ServerArgs:
  default=ServerArgs.stream_interval,
  help="The interval (or buffer size) for streaming in terms of the token length. A smaller value makes streaming smoother, while a larger value makes the throughput higher",
  )
+ parser.add_argument(
+ "--stream-output",
+ action="store_true",
+ help="Whether to output as a sequence of disjoint segments.",
+ )
  parser.add_argument(
  "--random-seed",
  type=int,
@@ -873,6 +884,14 @@ class ServerArgs:
  action="store_true",
  help="Enable users to pass custom logit processors to the server (disabled by default for security)",
  )
+ # Function Calling
+ parser.add_argument(
+ "--tool-call-parser",
+ type=str,
+ choices=["qwen25", "mistral", "llama3"],
+ default=ServerArgs.tool_call_parser,
+ help="Specify the parser for handling tool-call interactions. Options include: 'qwen25', 'mistral', and 'llama3'.",
+ )

  @classmethod
  def from_cli_args(cls, args: argparse.Namespace):
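
Note: the new --tool-call-parser argument is what activates the FunctionCallParser path in the OpenAI adapter above. A hedged usage sketch (the model path, port, and book_table tool are placeholders chosen for illustration, not taken from this diff):

# Launch the server with a tool-call parser, e.g.:
#   python -m sglang.launch_server --model-path meta-llama/Llama-3.1-8B-Instruct \
#       --tool-call-parser llama3 --port 30000
import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "Book a table for two at 7pm."}],
    "tools": [{
        "type": "function",
        "function": {
            "name": "book_table",  # hypothetical function, for illustration only
            "parameters": {
                "type": "object",
                "properties": {"people": {"type": "integer"}, "time": {"type": "string"}},
            },
        },
    }],
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload).json()
choice = resp["choices"][0]
# When the configured parser recognizes a tool call in the model output, the choice
# carries message.tool_calls with the parsed name/arguments and finish_reason "tool_calls".
print(choice.get("finish_reason"), choice["message"].get("tool_calls"))
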
sglang/srt/speculative/eagle_utils.py CHANGED
@@ -180,7 +180,6 @@ def generate_draft_decode_kv_indices(
  class EAGLEDraftInput(SpecInfo):
  def __init__(self):
  self.prev_mode = ForwardMode.DECODE
- self.sample_output = None

  self.scores: torch.Tensor = None
  self.score_list: List[torch.Tensor] = []
@@ -190,12 +189,16 @@ class EAGLEDraftInput(SpecInfo):
  self.cache_list: List[torch.Tenor] = []
  self.iter = 0

+ # shape: (b, hidden_size)
  self.hidden_states: torch.Tensor = None
+ # shape: (b,)
  self.verified_id: torch.Tensor = None
+ # shape: (b, vocab_size)
+ self.sample_output: torch.Tensor = None
+
  self.positions: torch.Tensor = None
  self.accept_length: torch.Tensor = None
- self.has_finished: bool = False
- self.unfinished_index: List[int] = None
+ self.accept_length_cpu: List[int] = None

  def load_server_args(self, server_args: ServerArgs):
  self.topk: int = server_args.speculative_eagle_topk
@@ -218,7 +221,7 @@ class EAGLEDraftInput(SpecInfo):
  :pre_len
  ] = req.prefix_indices

- batch.req_to_token_pool.req_to_token[req.req_pool_idx][pre_len:seq_len] = (
+ batch.req_to_token_pool.req_to_token[req.req_pool_idx, pre_len:seq_len] = (
  out_cache_loc[pt : pt + req.extend_input_len]
  )

@@ -228,6 +231,14 @@ class EAGLEDraftInput(SpecInfo):
  assert len(batch.extend_lens) == 1
  batch.input_ids = torch.concat((batch.input_ids[1:], self.verified_id))

+ def filter_batch(
+ self,
+ new_indices: torch.Tensor,
+ ):
+ self.sample_output = self.sample_output[: len(new_indices)]
+ self.hidden_states = self.hidden_states[: len(new_indices)]
+ self.verified_id = self.verified_id[: len(new_indices)]
+
  def prepare_for_decode(self, batch: ScheduleBatch):
  prob = self.sample_output # shape: (b * top_k, vocab) or (b, vocab)
  top = torch.topk(prob, self.topk, dim=-1)
@@ -287,7 +298,9 @@ class EAGLEDraftInput(SpecInfo):
  self.cache_list.append(batch.out_cache_loc)
  self.positions = (
  batch.seq_lens[:, None]
- + torch.ones([1, self.topk], device="cuda", dtype=torch.long) * self.iter
+ + torch.full(
+ [1, self.topk], fill_value=self.iter, device="cuda", dtype=torch.long
+ )
  ).flatten()

  bs = len(batch.seq_lens)
@@ -304,24 +317,25 @@ class EAGLEDraftInput(SpecInfo):

  def prepare_extend_after_decode(self, batch: ScheduleBatch):
  batch.out_cache_loc = batch.alloc_token_slots(self.verified_id.numel())
- batch.extend_lens = (self.accept_length + 1).tolist()
+ accept_length_cpu = batch.spec_info.accept_length_cpu
+ batch.extend_lens = [x + 1 for x in accept_length_cpu]
+ batch.seq_lens = batch.spec_info.seq_lens_for_draft_extend
+ seq_lens_cpu = batch.seq_lens.tolist()

  pt = 0
- seq_lens = batch.seq_lens.tolist()
-
  i = 0
-
  for req in batch.reqs:
  if req.finished():
  continue
  # assert seq_len - pre_len == req.extend_input_len
- input_len = self.accept_length[i] + 1
- seq_len = seq_lens[i]
+ input_len = batch.extend_lens[i]
+ seq_len = seq_lens_cpu[i]
  batch.req_to_token_pool.req_to_token[req.req_pool_idx][
  seq_len - input_len : seq_len
  ] = batch.out_cache_loc[pt : pt + input_len]
  pt += input_len
  i += 1
+ assert pt == batch.out_cache_loc.shape[0]

  self.positions = torch.empty_like(self.verified_id)
  new_verified_id = torch.empty_like(self.accept_length, dtype=torch.long)
@@ -337,7 +351,7 @@ class EAGLEDraftInput(SpecInfo):
  triton.next_power_of_2(self.spec_steps + 1),
  )

- batch.seq_lens_sum = sum(batch.seq_lens)
+ batch.seq_lens_sum = sum(seq_lens_cpu)
  batch.input_ids = self.verified_id
  self.verified_id = new_verified_id

@@ -565,6 +579,8 @@ class EagleVerifyInput(SpecInfo):
  finished_extend_len = {} # {rid:accept_length + 1}
  accept_index_cpu = accept_index.tolist()
  predict_cpu = predict.tolist()
+ has_finished = False
+
  # iterate every accepted token and check if req has finished after append the token
  # should be checked BEFORE free kv cache slots
  for i, (req, accept_index_row) in enumerate(zip(batch.reqs, accept_index_cpu)):
@@ -578,7 +594,7 @@ class EagleVerifyInput(SpecInfo):
  finished_extend_len[req.rid] = j + 1
  req.check_finished()
  if req.finished():
- draft_input.has_finished = True
+ has_finished = True
  # set all tokens after finished token to -1 and break
  accept_index[i, j + 1 :] = -1
  break
@@ -587,12 +603,12 @@ class EagleVerifyInput(SpecInfo):
  if not req.finished():
  new_accept_index.extend(new_accept_index_)
  unfinished_index.append(i)
+ req.spec_verify_ct += 1
  accept_length = (accept_index != -1).sum(dim=1) - 1

  accept_index = accept_index[accept_index != -1]
  accept_length_cpu = accept_length.tolist()
  verified_id = predict[accept_index]
- verified_id_cpu = verified_id.tolist()

  evict_mask = torch.full_like(self.draft_token, True, dtype=torch.bool)
  evict_mask[accept_index] = False
@@ -614,7 +630,13 @@ class EagleVerifyInput(SpecInfo):
  draft_input.verified_id = predict[new_accept_index]
  draft_input.hidden_states = batch.spec_info.hidden_states[new_accept_index]
  draft_input.accept_length = accept_length[unfinished_index]
- draft_input.unfinished_index = unfinished_index
+ draft_input.accept_length_cpu = [
+ accept_length_cpu[i] for i in unfinished_index
+ ]
+ if has_finished:
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens[unfinished_index]
+ else:
+ draft_input.seq_lens_for_draft_extend = batch.seq_lens

  logits_output.next_token_logits = logits_output.next_token_logits[accept_index]
  return (
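
Note: a toy illustration (not sglang code) of the accept-index bookkeeping used in the verification path above, where -1 marks rejected or padded draft slots per request:

import torch

accept_index = torch.tensor([[5, 6, -1, -1],
                             [2, 3, 4, 7]])    # rows = requests, cols = draft positions
predict = torch.arange(100)                    # stand-in for the verified token predictions

accept_length = (accept_index != -1).sum(dim=1) - 1   # extra tokens accepted per request
flat_index = accept_index[accept_index != -1]         # drop rejected slots, flatten
verified_id = predict[flat_index]                     # tokens that survive verification
print(accept_length.tolist(), verified_id.tolist())   # [1, 3] [5, 6, 2, 3, 4, 7]
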
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -13,6 +13,7 @@ from sglang.srt.model_executor.forward_batch_info import (
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
  from sglang.srt.speculative.eagle_utils import EAGLEDraftInput
+ from sglang.srt.utils import rank0_print


  class EAGLEWorker(TpModelWorker):
@@ -50,18 +51,18 @@ class EAGLEWorker(TpModelWorker):

  def forward_draft_decode(self, batch: ScheduleBatch):
  batch.spec_info.prepare_for_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)

  def forward_draft_extend(self, batch: ScheduleBatch):
  self._set_mem_pool(batch, self.model_runner)
  batch.spec_info.prepare_for_extend(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
  self.capture_for_decode(logits_output, forward_batch)
  self._set_mem_pool(batch, self.target_worker.model_runner)
@@ -134,26 +135,23 @@ class EAGLEWorker(TpModelWorker):
  batch.req_to_token_pool = runner.req_to_token_pool

  def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
+ seq_lens_backup = batch.seq_lens
+
  self._set_mem_pool(batch, self.model_runner)
  batch.forward_mode = ForwardMode.DRAFT_EXTEND
- if batch.spec_info.has_finished:
- index = batch.spec_info.unfinished_index
- seq_lens = batch.seq_lens
- batch.seq_lens = batch.seq_lens[index]
-
  batch.spec_info.prepare_extend_after_decode(batch)
+ batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
- forward_batch.capture_hidden_mode = CaptureHiddenMode.LAST
  logits_output = self.model_runner.forward(forward_batch)
-
- batch.spec_info.hidden_states = logits_output.hidden_states
  self.capture_for_decode(logits_output, forward_batch)
- batch.forward_mode = ForwardMode.DECODE
- if batch.spec_info.has_finished:
- batch.seq_lens = seq_lens
  self._set_mem_pool(batch, self.target_worker.model_runner)

+ # Restore backup.
+ # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
+ batch.forward_mode = ForwardMode.DECODE
+ batch.seq_lens = seq_lens_backup
+
  def capture_for_decode(
  self, logits_output: LogitsProcessorOutput, forward_batch: ForwardBatch
  ):