sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
@@ -18,25 +18,25 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.activation import SiluAndMul
-from vllm.model_executor.layers.layernorm import RMSNorm
-from vllm.model_executor.layers.linear import (
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     ReplicatedLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -20,7 +20,7 @@ import os
 import time
 import uuid
 from http import HTTPStatus
-from typing import Dict, List
+from typing import Dict, List, Optional

 from fastapi import HTTPException, Request, UploadFile
 from fastapi.responses import ORJSONResponse, StreamingResponse
@@ -40,6 +40,7 @@ from sglang.srt.conversation import (
     generate_chat_conv,
     register_conv_template,
 )
+from sglang.srt.function_call_parser import TOOLS_TAG_LIST, FunctionCallParser
 from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
@@ -71,7 +72,6 @@ from sglang.srt.openai_api.protocol import (
     TopLogprob,
     UsageInfo,
 )
-from sglang.srt.utils import TOOLS_TAG_LIST, parse_tool_response
 from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)
@@ -309,6 +309,7 @@ async def process_batch(tokenizer_manager, batch_id: str, batch_request: BatchRe
             ret,
             to_file=True,
             cache_report=tokenizer_manager.server_args.enable_cache_report,
+            tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
         )
     else:
         responses = v1_generate_response(
@@ -877,9 +878,6 @@ def v1_chat_generate_request(
         tools = None
         if request.tools and request.tool_choice != "none":
             request.skip_special_tokens = False
-            if request.stream:
-                logger.warning("Streaming is not supported with tools.")
-                request.stream = False
             if not isinstance(request.tool_choice, str):
                 tools = [
                     item.function.model_dump()
@@ -908,12 +906,26 @@ def v1_chat_generate_request(
                 openai_compatible_messages = openai_compatible_messages[:-1]
             else:
                 assistant_prefix = None
-            prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
-                openai_compatible_messages,
-                tokenize=True,
-                add_generation_prompt=True,
-                tools=tools,
-            )
+
+            try:
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+            except:
+                # This except branch will be triggered when the chosen model
+                # has a different tools input format that is not compatiable
+                # with openAI's apply_chat_template tool_call format, like Mistral.
+                tools = [t if "function" in t else {"function": t} for t in tools]
+                prompt_ids = tokenizer_manager.tokenizer.apply_chat_template(
+                    openai_compatible_messages,
+                    tokenize=True,
+                    add_generation_prompt=True,
+                    tools=tools,
+                )
+
             if assistant_prefix:
                 prompt_ids += tokenizer_manager.tokenizer.encode(assistant_prefix)
             stop = request.stop
@@ -1005,7 +1017,9 @@ def v1_chat_generate_request(
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]


-def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
+def v1_chat_generate_response(
+    request, ret, to_file=False, cache_report=False, tool_call_parser=None
+):
     choices = []

     for idx, ret_item in enumerate(ret):
@@ -1066,12 +1080,13 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
         if finish_reason == "stop":
             finish_reason = "tool_calls"
             try:
-                text, call_info_list = parse_tool_response(text, tools)  # noqa
+                parser = FunctionCallParser(tools, tool_call_parser)
+                full_normal_text, call_info_list = parser.parse_non_stream(text)
                 tool_calls = [
                     ToolCall(
-                        id=str(call_info[0]),
+                        id=str(call_info.tool_index),
                         function=FunctionResponse(
-                            name=call_info[1], arguments=call_info[2]
+                            name=call_info.name, arguments=call_info.parameters
                         ),
                     )
                     for call_info in call_info_list
@@ -1172,6 +1187,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
    adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

    if adapted_request.stream:
+        parser_dict = {}

        async def generate_stream_resp():
            is_firsts = {}
@@ -1184,6 +1200,7 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
                adapted_request, raw_request
            ):
                index = content.get("index", 0)
+                text = content["text"]

                is_first = is_firsts.get(index, True)
                stream_buffer = stream_buffers.get(index, "")
@@ -1263,29 +1280,111 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):

                text = content["text"]
                delta = text[len(stream_buffer) :]
-                stream_buffer = stream_buffer + delta
-                choice_data = ChatCompletionResponseStreamChoice(
-                    index=index,
-                    delta=DeltaMessage(content=delta),
-                    finish_reason=(finish_reason["type"] if finish_reason else ""),
-                    matched_stop=(
-                        finish_reason["matched"]
-                        if finish_reason and "matched" in finish_reason
-                        else None
-                    ),
-                    logprobs=choice_logprobs,
-                )
-                chunk = ChatCompletionStreamResponse(
-                    id=content["meta_info"]["id"],
-                    choices=[choice_data],
-                    model=request.model,
-                )
+                new_stream_buffer = stream_buffer + delta

-                is_firsts[index] = is_first
-                stream_buffers[index] = stream_buffer
-                n_prev_tokens[index] = n_prev_token
+                if request.tool_choice != "none" and request.tools:
+                    if index not in parser_dict:
+                        parser_dict[index] = FunctionCallParser(
+                            tools=request.tools,
+                            tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
+                        )
+                    parser = parser_dict[index]
+
+                    # parse_increment => returns (normal_text, calls)
+                    normal_text, calls = parser.parse_stream_chunk(delta)
+
+                    # 1) if there's normal_text, output it as normal content
+                    if normal_text:
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(content=normal_text),
+                            finish_reason=(
+                                finish_reason["type"] if finish_reason else ""
+                            ),
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"
+
+                    # 2) if we found calls, we output them as separate chunk(s)
+                    for call_item in calls:
+                        # transform call_item -> FunctionResponse + ToolCall
+
+                        if (
+                            content["meta_info"]["finish_reason"]
+                            and content["meta_info"]["finish_reason"]["type"]
+                            == "stop"
+                        ):
+                            latest_delta_len = 0
+                            if isinstance(call_item.parameters, str):
+                                latest_delta_len = len(call_item.parameters)
+
+                            expected_call = json.dumps(
+                                parser.multi_format_parser.detectors[0]
+                                .prev_tool_call_arr[index]
+                                .get("arguments", {}),
+                                ensure_ascii=False,
+                            )
+                            actual_call = parser.multi_format_parser.detectors[
+                                0
+                            ].streamed_args_for_tool[index]
+                            if latest_delta_len > 0:
+                                actual_call = actual_call[:-latest_delta_len]
+                            remaining_call = expected_call.replace(
+                                actual_call, "", 1
+                            )
+                            call_item.parameters = remaining_call
+
+                        tool_call = ToolCall(
+                            id=str(call_item.tool_index),
+                            function=FunctionResponse(
+                                name=call_item.name,
+                                arguments=call_item.parameters,
+                            ),
+                        )
+                        choice_data = ChatCompletionResponseStreamChoice(
+                            index=index,
+                            delta=DeltaMessage(
+                                role="assistant", tool_calls=[tool_call]
+                            ),
+                            finish_reason="tool_call",
+                        )
+                        chunk = ChatCompletionStreamResponse(
+                            id=content["meta_info"]["id"],
+                            choices=[choice_data],
+                            model=request.model,
+                        )
+                        yield f"data: {chunk.model_dump_json()}\n\n"

-                yield f"data: {chunk.model_dump_json()}\n\n"
+                    stream_buffers[index] = new_stream_buffer
+                    is_firsts[index] = is_first
+
+                else:
+                    # No tool calls => just treat this as normal text
+                    choice_data = ChatCompletionResponseStreamChoice(
+                        index=index,
+                        delta=DeltaMessage(content=delta),
+                        finish_reason=(
+                            finish_reason["type"] if finish_reason else ""
+                        ),
+                        matched_stop=(
+                            finish_reason["matched"]
+                            if finish_reason and "matched" in finish_reason
+                            else None
+                        ),
+                        logprobs=choice_logprobs,
+                    )
+                    chunk = ChatCompletionStreamResponse(
+                        id=content["meta_info"]["id"],
+                        choices=[choice_data],
+                        model=request.model,
+                    )
+                    yield f"data: {chunk.model_dump_json()}\n\n"
+                    stream_buffers[index] = new_stream_buffer
+                    is_firsts[index] = is_first

                if request.stream_options and request.stream_options.include_usage:
                    total_prompt_tokens = sum(
@@ -1333,7 +1432,10 @@ async def v1_chat_completions(tokenizer_manager, raw_request: Request):
        ret = [ret]

    response = v1_chat_generate_response(
-        request, ret, cache_report=tokenizer_manager.server_args.enable_cache_report
+        request,
+        ret,
+        cache_report=tokenizer_manager.server_args.enable_cache_report,
+        tool_call_parser=tokenizer_manager.server_args.tool_call_parser,
    )

    return response
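
The streaming branch above now emits tool-call deltas instead of disabling streaming when tools are supplied. A minimal client-side sketch, not part of this diff: it assumes an sglang server started with a tool-call parser configured (server_args.tool_call_parser, e.g. a --tool-call-parser CLI flag) and listening on localhost:30000, and a hypothetical get_weather tool.

import openai

client = openai.OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

tools = [{
    "type": "function",
    "function": {
        "name": "get_weather",  # hypothetical tool for illustration
        "description": "Look up the current weather for a city.",
        "parameters": {
            "type": "object",
            "properties": {"city": {"type": "string"}},
            "required": ["city"],
        },
    },
}]

# Streaming is no longer forced off when tools are present.
stream = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "What is the weather in Paris?"}],
    tools=tools,
    stream=True,
)
for chunk in stream:
    delta = chunk.choices[0].delta
    if delta.tool_calls:      # emitted by the tool-call branch above
        print(delta.tool_calls[0].function)
    elif delta.content:       # plain text still arrives as content deltas
        print(delta.content, end="")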
@@ -180,6 +180,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class CompletionResponseChoice(BaseModel):
@@ -261,7 +262,7 @@ class Function(BaseModel):
     """Function descriptions."""

     description: Optional[str] = Field(default=None, examples=[None])
-    name: str
+    name: Optional[str] = None
     parameters: Optional[object] = None

@@ -275,7 +276,7 @@ class Tool(BaseModel):
 class ToolChoiceFuncName(BaseModel):
     """The name of tool choice function."""

-    name: str
+    name: Optional[str] = None


 class ToolChoice(BaseModel):
@@ -322,13 +323,14 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    session_params: Optional[Dict] = None


 class FunctionResponse(BaseModel):
     """Function response."""

-    name: str
-    arguments: str
+    name: Optional[str] = None
+    arguments: Optional[str] = None


 class ToolCall(BaseModel):
@@ -365,6 +367,7 @@ class ChatCompletionResponse(BaseModel):
 class DeltaMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionResponseStreamChoice(BaseModel):
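
For reference, a streamed tool-call delta built from the extended models might look as follows. This is a sketch, not from the diff; it assumes ToolCall accepts the id/function fields exactly as the adapter constructs them above, and the tool name and arguments are illustrative.

from sglang.srt.openai_api.protocol import DeltaMessage, FunctionResponse, ToolCall

delta = DeltaMessage(
    role="assistant",
    tool_calls=[
        ToolCall(
            id="0",
            function=FunctionResponse(name="get_weather", arguments='{"city": "Paris"}'),
        )
    ],
)
# This delta is what gets wrapped into the ChatCompletionStreamResponse chunk.
print(delta.model_dump_json(exclude_none=True))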
@@ -0,0 +1,38 @@
+import json
+from abc import ABC, abstractmethod
+from functools import lru_cache
+from typing import Any, Dict, List, Optional
+
+import dill
+import torch
+
+
+@lru_cache(maxsize=None)
+def _cache_from_str(json_str: str):
+    """Deserialize a json string to a Callable object.
+    This function is cached to avoid redundant deserialization.
+    """
+    data = json.loads(json_str)
+    return dill.loads(bytes.fromhex(data["callable"]))
+
+
+class CustomLogitProcessor(ABC):
+    """Abstract base class for callable functions."""
+
+    @abstractmethod
+    def __call__(
+        self,
+        logits: torch.Tensor,
+        custom_param_list: Optional[List[Dict[str, Any]]] = None,
+    ) -> torch.Tensor:
+        """Define the callable behavior."""
+        raise NotImplementedError
+
+    def to_str(self) -> str:
+        """Serialize the callable function to a JSON-compatible string."""
+        return json.dumps({"callable": dill.dumps(self).hex()})
+
+    @classmethod
+    def from_str(cls, json_str: str):
+        """Deserialize a callable function from a JSON string."""
+        return _cache_from_str(json_str)
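
The new module above defines the contract for user-supplied logit processors: subclasses implement __call__ and travel to the server as a dill-serialized hex payload inside a JSON string. A minimal sketch of a subclass and a serialization round trip; the class name, the per-row interpretation of custom_param_list, and the "blocked_token_ids" key are illustrative assumptions, not part of the package.

import torch

from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor


class BlocklistLogitProcessor(CustomLogitProcessor):
    """Mask out token ids listed in each request's custom_params (hypothetical key)."""

    def __call__(self, logits, custom_param_list=None):
        if custom_param_list is None:
            return logits
        # Assumes custom_param_list holds one dict (or None) per request row.
        for row, params in enumerate(custom_param_list):
            if params and "blocked_token_ids" in params:
                logits[row, params["blocked_token_ids"]] = float("-inf")
        return logits


serialized = BlocklistLogitProcessor().to_str()        # JSON string with a dill hex payload
restored = CustomLogitProcessor.from_str(serialized)   # cached deserialization
assert callable(restored)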
@@ -3,11 +3,16 @@ from typing import List
 import torch

 from sglang.srt.sampling.penaltylib.orchestrator import _BatchedPenalizer, _TokenIDs
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import get_compiler_backend

-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
+
+@torch.compile(dynamic=True, backend=get_compiler_backend())
+def apply_scaling_penalties(logits, scaling_penalties):
+    logits[:] = torch.where(
+        logits > 0,
+        logits / scaling_penalties,
+        logits * scaling_penalties,
+    )


 class BatchedRepetitionPenalizer(_BatchedPenalizer):
@@ -61,16 +66,8 @@ class BatchedRepetitionPenalizer(_BatchedPenalizer):
         self.cumulated_repetition_penalties[mask] = self.repetition_penalties[mask]

     def _apply(self, logits: torch.Tensor) -> torch.Tensor:
-        if is_cuda:
-            return sampling_scaling_penalties(
-                logits, self.cumulated_repetition_penalties
-            )
-        else:
-            return torch.where(
-                logits > 0,
-                logits / self.cumulated_repetition_penalties,
-                logits * self.cumulated_repetition_penalties,
-            )
+        apply_scaling_penalties(logits, self.cumulated_repetition_penalties)
+        return logits

     def _filter(self, indices_to_keep: List[int], indices_tensor_to_keep: torch.Tensor):
         self.repetition_penalties = self.repetition_penalties[indices_tensor_to_keep]
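
The sgl_kernel fast path is replaced by a torch.compile'd helper that applies the usual repetition penalty in place (divide positive logits, multiply negative ones). A tiny eager-mode illustration of the same formula on toy values:

import torch

logits = torch.tensor([[2.0, -1.0, 0.5]])
penalties = torch.tensor([[1.5, 1.5, 1.5]])  # cumulated repetition penalties

out = torch.where(logits > 0, logits / penalties, logits * penalties)
print(out)  # tensor([[ 1.3333, -1.5000,  0.3333]])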
@@ -3,17 +3,15 @@ from __future__ import annotations
 import dataclasses
 import logging
 import threading
-from typing import TYPE_CHECKING, Callable, List, Optional
+from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple

 import torch

-from sglang.srt.utils import is_cuda_available
-
-is_cuda = is_cuda_available()
-if is_cuda:
-    from sgl_kernel import sampling_scaling_penalties
-
 import sglang.srt.sampling.penaltylib as penaltylib
+from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
+from sglang.srt.sampling.penaltylib.penalizers.repetition_penalty import (
+    apply_scaling_penalties,
+)

 logger = logging.getLogger(__name__)

@@ -36,6 +34,9 @@ class SamplingBatchInfo:
     # Dispatch in CUDA graph
     need_min_p_sampling: bool

+    # Whether any request has custom logit processor
+    has_custom_logit_processor: bool
+
     # Bias Tensors
     vocab_size: int
     grammars: Optional[List] = None
@@ -52,6 +53,14 @@
     # Device
     device: str = "cuda"

+    # Custom Parameters
+    custom_params: Optional[List[Optional[Dict[str, Any]]]] = None
+
+    # Custom Logit Processor
+    custom_logit_processor: Optional[
+        Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]
+    ] = None
+
     @classmethod
     def from_schedule_batch(
         cls, batch: ScheduleBatch, vocab_size: int, enable_overlap_schedule: bool
@@ -76,6 +85,39 @@ class SamplingBatchInfo:
             [r.sampling_params.min_p for r in reqs], dtype=torch.float
         ).to(device, non_blocking=True)

+        # Check if any request has custom logit processor
+        has_custom_logit_processor = (
+            batch.enable_custom_logit_processor  # check the flag first.
+            and any(r.custom_logit_processor for r in reqs)  # then check the requests.
+        )
+
+        if has_custom_logit_processor:
+            # Merge the same type of custom logit processors together
+            processor_dict = {}
+            for i, r in enumerate(reqs):
+                if r.custom_logit_processor is None:
+                    continue
+                processor_str = r.custom_logit_processor
+                if processor_str not in processor_dict:
+                    processor_dict[processor_str] = []
+                processor_dict[processor_str].append(i)
+
+            merged_custom_logit_processor = {
+                hash(processor_str): (
+                    # The deserialized custom logit processor object
+                    CustomLogitProcessor.from_str(processor_str),
+                    # The mask tensor for the requests that use this custom logit processor
+                    torch.zeros(len(reqs), dtype=torch.bool)
+                    .scatter_(0, torch.tensor(true_indices), True)
+                    .to(device, non_blocking=True),
+                )
+                for processor_str, true_indices in processor_dict.items()
+            }
+            custom_params = [r.sampling_params.custom_params for r in reqs]
+        else:
+            merged_custom_logit_processor = None
+            custom_params = None
+
         ret = cls(
             temperatures=temperatures,
             top_ps=top_ps,
@@ -83,8 +125,11 @@ class SamplingBatchInfo:
             min_ps=min_ps,
             need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
             is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
+            has_custom_logit_processor=has_custom_logit_processor,
             vocab_size=vocab_size,
             device=device,
+            custom_params=custom_params,
+            custom_logit_processor=merged_custom_logit_processor,
         )
         # TODO (lianmin): `need_min_p_sampling` needs to be updated in filter and merge.

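from_schedule_batch now groups requests by their serialized processor string and builds one boolean mask per group. A standalone sketch of that grouping, with toy strings standing in for the serialized processors:

import torch

req_processors = ["procA", None, "procA", "procB"]  # one entry per request in the batch

processor_dict = {}
for i, p in enumerate(req_processors):
    if p is None:
        continue
    processor_dict.setdefault(p, []).append(i)

masks = {
    p: torch.zeros(len(req_processors), dtype=torch.bool).scatter_(
        0, torch.tensor(indices), True
    )
    for p, indices in processor_dict.items()
}
# masks["procA"] -> tensor([ True, False,  True, False])
# masks["procB"] -> tensor([False, False, False,  True])
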
@@ -184,6 +229,8 @@ class SamplingBatchInfo:

     def filter_batch(self, unfinished_indices: List[int], new_indices: torch.Tensor):
         self.penalizer_orchestrator.filter(unfinished_indices, new_indices)
+        if self.has_custom_logit_processor:
+            self._filter_batch_custom_logit_processor(unfinished_indices, new_indices)

         for item in [
             "temperatures",
@@ -196,6 +243,27 @@ class SamplingBatchInfo:
            if value is not None:  # logit_bias can be None
                setattr(self, item, value[new_indices])

+    def _filter_batch_custom_logit_processor(
+        self, unfinished_indices: List[int], new_indices: torch.Tensor
+    ):
+        """Filter the custom logit processor and custom params"""
+
+        self.custom_logit_processor = {
+            k: (p, mask[new_indices])
+            for k, (p, mask) in self.custom_logit_processor.items()
+            if any(
+                mask[new_indices]
+            )  # ignore the custom logit processor whose mask is all False
+        }
+        self.custom_params = [self.custom_params[i] for i in unfinished_indices]
+
+        # If the custom logit processor is an empty dict, set the flag to False,
+        # and set the custom logit processor and custom params to None.
+        if len(self.custom_logit_processor) == 0:
+            self.custom_logit_processor = None
+            self.custom_params = None
+            self.has_custom_logit_processor = False
+
     @staticmethod
     def merge_bias_tensor(
         lhs: torch.Tensor,
@@ -221,9 +289,76 @@ class SamplingBatchInfo:

         return None

+    @staticmethod
+    def merge_custom_logit_processor(
+        lhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        rhs: Optional[Dict[int, Tuple[CustomLogitProcessor, torch.Tensor]]],
+        bs1: int,
+        bs2: int,
+        device: str,
+    ):
+        if lhs is None and rhs is None:
+            return None
+        lhs, rhs = lhs or {}, rhs or {}
+
+        keys = set(lhs.keys()).union(set(rhs.keys()))
+        merged_dict = {}
+
+        for k in keys:
+            # Get the logit processor object
+            processor = lhs[k][0] if k in lhs else rhs[k][0]
+            # Get and merge the mask tensors from the two dicts
+            left_mask = (
+                lhs[k][1]
+                if k in lhs
+                else torch.zeros(bs1, dtype=torch.bool, device=device)
+            )
+            right_mask = (
+                rhs[k][1]
+                if k in rhs
+                else torch.zeros(bs2, dtype=torch.bool, device=device)
+            )
+            merged_dict[k] = (processor, torch.cat([left_mask, right_mask]))
+
+            assert merged_dict[k][1].shape[0] == bs1 + bs2, (
+                f"The batch size of merged mask ({merged_dict[k][1].shape[0]}) does not match "
+                f"the sum of the batch sizes of the two masks ({bs1 + bs2})"
+                f"\n{left_mask=}\n{right_mask=}\n{bs1=}\n{bs2=}"
+                f"\n{lhs=}\n{rhs=}"
+            )
+
+        return merged_dict
+
     def merge_batch(self, other: "SamplingBatchInfo"):
         self.penalizer_orchestrator.merge(other.penalizer_orchestrator)

+        # Merge the logit bias tensor
+        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
+            self.logit_bias, other.logit_bias, len(self), len(other), self.device
+        )
+        # Merge the custom logit processors and custom params lists
+        if self.has_custom_logit_processor or other.has_custom_logit_processor:
+            # Merge the custom logit processors
+            self.custom_logit_processor = (
+                SamplingBatchInfo.merge_custom_logit_processor(
+                    self.custom_logit_processor,
+                    other.custom_logit_processor,
+                    len(self),
+                    len(other),
+                    self.device,
+                )
+            )
+            # Merge the custom params lists
+            self.custom_params = self.custom_params or [None] * len(self)
+            other.custom_params = other.custom_params or [None] * len(other)
+            self.custom_params.extend(other.custom_params)
+
+            # Set the flag to True if any of the two has custom logit processor
+            self.has_custom_logit_processor = True
+
+        # Note: becasue the __len()__ operator is defined on the temperatures tensor,
+        # please make sure any merge operation with len(self) or len(other) is done before
+        # the merge operation of the temperatures tensor below.
         for item in [
             "temperatures",
             "top_ps",
@@ -235,9 +370,6 @@ class SamplingBatchInfo:
            setattr(self, item, torch.concat([self_val, other_val]))

        self.is_all_greedy = self.is_all_greedy and other.is_all_greedy
-        self.logit_bias = SamplingBatchInfo.merge_bias_tensor(
-            self.logit_bias, other.logit_bias, len(self), len(other), self.device
-        )
        self.need_min_p_sampling = self.need_min_p_sampling or other.need_min_p_sampling

    def apply_logits_bias(self, logits: torch.Tensor):
@@ -251,14 +383,7 @@ class SamplingBatchInfo:

        # repetition
        if self.scaling_penalties is not None:
-            if is_cuda:
-                logits[:] = sampling_scaling_penalties(logits, self.scaling_penalties)
-            else:
-                logits[:] = torch.where(
-                    logits > 0,
-                    logits / self.scaling_penalties,
-                    logits * self.scaling_penalties,
-                )
+            apply_scaling_penalties(logits, self.scaling_penalties)

        # Apply regex vocab_mask
        if self.vocab_mask is not None:
@@ -13,7 +13,7 @@
 # ==============================================================================
 """Sampling parameters for text generation."""

-from typing import List, Optional, Union
+from typing import Any, Dict, List, Optional, Union

 _SAMPLING_EPS = 1e-6

@@ -48,6 +48,7 @@ class SamplingParams:
        no_stop_trim: bool = False,
        ignore_eos: bool = False,
        skip_special_tokens: bool = True,
+        custom_params: Optional[Dict[str, Any]] = None,
    ) -> None:
        self.temperature = temperature
        self.top_p = top_p
@@ -71,6 +72,7 @@ class SamplingParams:
        self.json_schema = json_schema
        self.ebnf = ebnf
        self.no_stop_trim = no_stop_trim
+        self.custom_params = custom_params

        # Process some special cases
        if self.temperature < _SAMPLING_EPS:
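
With the new field, per-request values can be attached to sampling and later handed to a custom logit processor as custom_param_list. A minimal construction sketch; only fields visible in this diff are used, and the dictionary key is illustrative.

from sglang.srt.sampling.sampling_params import SamplingParams

params = SamplingParams(
    temperature=0.7,
    top_p=0.9,
    custom_params={"blocked_token_ids": [128000, 128001]},  # hypothetical key
)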