sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/bench_one_batch.py +3 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/check_env.py +3 -3
  4. sglang/lang/chat_template.py +44 -0
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/deepseekvl2.py +3 -0
  7. sglang/srt/configs/device_config.py +1 -1
  8. sglang/srt/configs/internvl.py +696 -0
  9. sglang/srt/configs/janus_pro.py +3 -0
  10. sglang/srt/configs/kimi_vl.py +38 -0
  11. sglang/srt/configs/kimi_vl_moonvit.py +32 -0
  12. sglang/srt/configs/model_config.py +32 -0
  13. sglang/srt/constrained/xgrammar_backend.py +11 -19
  14. sglang/srt/conversation.py +151 -3
  15. sglang/srt/disaggregation/decode.py +4 -1
  16. sglang/srt/disaggregation/mini_lb.py +74 -23
  17. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  18. sglang/srt/disaggregation/nixl/conn.py +241 -71
  19. sglang/srt/disaggregation/utils.py +44 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  21. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  22. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  23. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  24. sglang/srt/distributed/parallel_state.py +22 -1
  25. sglang/srt/entrypoints/engine.py +58 -24
  26. sglang/srt/entrypoints/http_server.py +28 -1
  27. sglang/srt/entrypoints/verl_engine.py +3 -2
  28. sglang/srt/function_call_parser.py +97 -0
  29. sglang/srt/hf_transformers_utils.py +22 -1
  30. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
  31. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  32. sglang/srt/layers/attention/flashinfer_backend.py +129 -94
  33. sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
  34. sglang/srt/layers/attention/flashmla_backend.py +3 -0
  35. sglang/srt/layers/attention/merge_state.py +46 -0
  36. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  37. sglang/srt/layers/attention/vision.py +290 -163
  38. sglang/srt/layers/dp_attention.py +5 -2
  39. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  40. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  41. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  48. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
  49. sglang/srt/layers/quantization/__init__.py +2 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  52. sglang/srt/layers/quantization/deep_gemm.py +6 -1
  53. sglang/srt/layers/quantization/fp8.py +108 -95
  54. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  55. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  56. sglang/srt/layers/quantization/kv_cache.py +3 -10
  57. sglang/srt/layers/quantization/utils.py +0 -5
  58. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  59. sglang/srt/layers/utils.py +35 -0
  60. sglang/srt/lora/layers.py +35 -9
  61. sglang/srt/lora/lora_manager.py +81 -35
  62. sglang/srt/managers/cache_controller.py +115 -119
  63. sglang/srt/managers/data_parallel_controller.py +52 -34
  64. sglang/srt/managers/io_struct.py +10 -0
  65. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  66. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  67. sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
  68. sglang/srt/managers/schedule_batch.py +44 -16
  69. sglang/srt/managers/schedule_policy.py +11 -5
  70. sglang/srt/managers/scheduler.py +291 -72
  71. sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
  72. sglang/srt/managers/tokenizer_manager.py +24 -13
  73. sglang/srt/managers/tp_worker.py +60 -28
  74. sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
  75. sglang/srt/mem_cache/chunk_cache.py +2 -0
  76. sglang/srt/mem_cache/memory_pool.py +70 -36
  77. sglang/srt/model_executor/cuda_graph_runner.py +82 -19
  78. sglang/srt/model_executor/forward_batch_info.py +31 -1
  79. sglang/srt/model_executor/model_runner.py +159 -90
  80. sglang/srt/model_loader/loader.py +18 -11
  81. sglang/srt/models/clip.py +4 -4
  82. sglang/srt/models/deepseek_janus_pro.py +1 -1
  83. sglang/srt/models/deepseek_nextn.py +2 -277
  84. sglang/srt/models/deepseek_v2.py +132 -37
  85. sglang/srt/models/gemma3_mm.py +1 -1
  86. sglang/srt/models/internlm2.py +3 -0
  87. sglang/srt/models/internvl.py +670 -0
  88. sglang/srt/models/kimi_vl.py +308 -0
  89. sglang/srt/models/kimi_vl_moonvit.py +639 -0
  90. sglang/srt/models/llama.py +93 -31
  91. sglang/srt/models/llama4.py +54 -7
  92. sglang/srt/models/llama_eagle.py +4 -1
  93. sglang/srt/models/llama_eagle3.py +4 -1
  94. sglang/srt/models/minicpmv.py +1 -1
  95. sglang/srt/models/mllama.py +1 -1
  96. sglang/srt/models/phi3_small.py +16 -2
  97. sglang/srt/models/qwen2_5_vl.py +8 -4
  98. sglang/srt/models/qwen2_moe.py +8 -3
  99. sglang/srt/models/qwen2_vl.py +4 -16
  100. sglang/srt/models/qwen3_moe.py +8 -3
  101. sglang/srt/models/xiaomi_mimo.py +171 -0
  102. sglang/srt/openai_api/adapter.py +58 -62
  103. sglang/srt/openai_api/protocol.py +38 -16
  104. sglang/srt/reasoning_parser.py +2 -2
  105. sglang/srt/sampling/sampling_batch_info.py +54 -2
  106. sglang/srt/sampling/sampling_params.py +2 -0
  107. sglang/srt/server_args.py +93 -24
  108. sglang/srt/speculative/eagle_worker.py +3 -2
  109. sglang/srt/utils.py +123 -10
  110. sglang/test/runners.py +4 -0
  111. sglang/test/test_block_fp8.py +2 -2
  112. sglang/test/test_deepep_utils.py +219 -0
  113. sglang/test/test_utils.py +32 -1
  114. sglang/version.py +1 -1
  115. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
  116. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
  117. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  118. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/xiaomi_mimo.py (new file)
@@ -0,0 +1,171 @@
+ # Adapted from qwen2.py
+
+ from functools import partial
+ from typing import Any, Dict, Iterable, Optional, Tuple
+
+ import torch
+ from torch import nn
+
+ from sglang.srt.distributed import (
+     get_tensor_model_parallel_rank,
+     get_tensor_model_parallel_world_size,
+     split_tensor_along_last_dim,
+     tensor_model_parallel_all_gather,
+ )
+ from sglang.srt.layers.layernorm import RMSNorm
+ from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+ from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.pooler import Pooler, PoolingType
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.radix_attention import RadixAttention
+ from sglang.srt.layers.rotary_embedding import get_rope
+ from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+ from sglang.srt.utils import add_prefix
+
+ MiMoConfig = None
+
+
+ class MiMoModel(Qwen2Model):
+     def __init__(
+         self,
+         config: MiMoConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__(
+             config=config,
+             quant_config=quant_config,
+             prefix=prefix,
+             decoder_layer_type=Qwen2DecoderLayer,
+         )
+
+
+ class MiMoForCausalLM(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
+     def __init__(
+         self,
+         config: MiMoConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         self.config = config
+         self.quant_config = quant_config
+         self.model = MiMoModel(
+             config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+         )
+         if config.tie_word_embeddings:
+             self.lm_head = self.model.embed_tokens
+         else:
+             self.lm_head = ParallelLMHead(
+                 config.vocab_size,
+                 config.hidden_size,
+                 quant_config=quant_config,
+                 prefix=add_prefix("lm_head", prefix),
+             )
+         self.logits_processor = LogitsProcessor(config)
+         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+     def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+         return self.model.get_input_embeddings(input_ids)
+
+     @torch.no_grad()
+     def forward(
+         self,
+         input_ids: torch.Tensor,
+         positions: torch.Tensor,
+         forward_batch: ForwardBatch,
+         input_embeds: torch.Tensor = None,
+         get_embedding: bool = False,
+     ) -> torch.Tensor:
+         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+         if not get_embedding:
+             return self.logits_processor(
+                 input_ids, hidden_states, self.lm_head, forward_batch
+             )
+         else:
+             return self.pooler(hidden_states, forward_batch)
+
+     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+         stacked_params_mapping = [
+             # (param_name, shard_name, shard_id)
+             ("qkv_proj", "q_proj", "q"),
+             ("qkv_proj", "k_proj", "k"),
+             ("qkv_proj", "v_proj", "v"),
+             ("gate_up_proj", "gate_proj", 0),
+             ("gate_up_proj", "up_proj", 1),
+         ]
+
+         params_dict = dict(self.named_parameters())
+         for name, loaded_weight in weights:
+             if (
+                 "rotary_emb.inv_freq" in name
+                 or "projector" in name
+                 or "mtp_layers" in name
+             ):
+                 continue
+             if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                 # Models trained using ColossalAI may include these tensors in
+                 # the checkpoint. Skip them.
+                 continue
+             if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                 continue
+             if name.startswith("model.vision_tower") and name not in params_dict:
+                 continue
+
+             for param_name, weight_name, shard_id in stacked_params_mapping:
+                 if weight_name not in name:
+                     continue
+                 name = name.replace(weight_name, param_name)
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = param.weight_loader
+                 weight_loader(param, loaded_weight, shard_id)
+                 break
+             else:
+                 # Skip loading extra bias for GPTQ models.
+                 if name.endswith(".bias") and name not in params_dict:
+                     continue
+                 param = params_dict[name]
+                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                 weight_loader(param, loaded_weight)
+
+     def get_embed_and_head(self):
+         return self.model.embed_tokens.weight, self.lm_head.weight
+
+     def set_embed_and_head(self, embed, head):
+         del self.model.embed_tokens.weight
+         del self.lm_head.weight
+         self.model.embed_tokens.weight = embed
+         self.lm_head.weight = head
+         torch.cuda.empty_cache()
+         torch.cuda.synchronize()
+
+     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+         self.model.load_kv_cache_scales(quantization_param_path)
+
+
+ EntryClass = MiMoForCausalLM
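
The new file registers MiMoForCausalLM as an sglang entry class on top of the Qwen2 decoder stack, so Xiaomi MiMo checkpoints load like any other dense model. A minimal serving sketch (the checkpoint name, port, and prompt below are illustrative and not part of this diff):

    # Shell: python -m sglang.launch_server --model-path XiaomiMiMo/MiMo-7B-RL --port 30000
    import openai

    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Summarize multi-token prediction in one line."}],
        max_tokens=64,
    )
    print(resp.choices[0].message.content)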
sglang/srt/openai_api/adapter.py
@@ -14,6 +14,7 @@
  """Conversion between OpenAI APIs and native SRT APIs"""

  import asyncio
+ import base64
  import json
  import logging
  import os
@@ -36,6 +37,7 @@ from sglang.srt.conversation import (
      chat_template_exists,
      generate_chat_conv,
      generate_embedding_convs,
+     get_conv_template_by_model_path,
      register_conv_template,
  )
  from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +165,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
      else:
          chat_template_name = chat_template_arg

- # Check chat-template
- # TODO:
- # 1. Do not import any code from sglang.lang
- # 2. For VLM, when chat_template_arg is None, set it automatically by guessing from model_path.
+
+ def guess_chat_template_name_from_model_path(model_path):
+     global chat_template_name
+     chat_template_name = get_conv_template_by_model_path(model_path)
+     if chat_template_name is not None:
+         logger.info(
+             f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
+         )


  async def v1_files_create(
@@ -523,6 +529,7 @@
          "temperature": request.temperature,
          "max_new_tokens": request.max_tokens,
          "min_new_tokens": request.min_tokens,
+         "thinking_budget": request.thinking_budget,
          "stop": request.stop,
          "stop_token_ids": request.stop_token_ids,
          "top_p": request.top_p,
@@ -894,6 +901,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
      return response


+ def _get_enable_thinking_from_request(request_obj):
+     """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
+
+     Args:
+         request_obj: The request object (or an item from a list of requests).
+
+     Returns:
+         The boolean value of 'enable_thinking' if found and not True, otherwise True.
+     """
+     if (
+         hasattr(request_obj, "chat_template_kwargs")
+         and request_obj.chat_template_kwargs
+         and request_obj.chat_template_kwargs.get("enable_thinking") is not None
+     ):
+         return request_obj.chat_template_kwargs.get("enable_thinking")
+     return True
+
+
  def v1_chat_generate_request(
      all_requests: List[ChatCompletionRequest],
      tokenizer_manager,
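
This helper centralizes the enable_thinking lookup that v1_chat_generate_response previously duplicated for list and single requests. On the client side the flag still travels inside chat_template_kwargs; a hedged sketch of a raw request (server URL and model name are placeholders):

    import requests

    payload = {
        "model": "default",
        "messages": [{"role": "user", "content": "What is 2 + 2?"}],
        # Disable the model's thinking phase for this request.
        "chat_template_kwargs": {"enable_thinking": False},
    }
    resp = requests.post("http://127.0.0.1:30000/v1/chat/completions", json=payload)
    print(resp.json()["choices"][0]["message"]["content"])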
@@ -943,47 +968,23 @@

      if chat_template_name is None:
          openai_compatible_messages = []
-         if (
-             tools
-             and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
-         ):
-             # add function call prompt to deepseekv3
-             openai_compatible_messages.append(
-                 {
-                     "role": "system",
-                     "content": """You are a helpful Assistant.
- ## Tools
- ### Function
- You have the following functions available:
- """
-                     + "".join(
-                         [
-                             f"""
- - `{tool['name']}`:
- ```json
- {json.dumps(tool)}
- ```
- """
-                             for tool in tools
-                         ]
-                     ),
-                 }
-             )

          for message in request.messages:
              if message.content is None:
                  message.content = ""
-             if isinstance(message.content, str):
-                 openai_compatible_messages.append(
-                     {"role": message.role, "content": message.content}
-                 )
+             msg_dict = message.dict()
+             if isinstance(msg_dict.get("content"), list):
+                 for chunk in msg_dict["content"]:
+                     if isinstance(chunk, dict) and chunk.get("type") == "text":
+                         new_msg = msg_dict.copy()
+                         new_msg["content"] = chunk["text"]
+                         new_msg = {
+                             k: v for k, v in new_msg.items() if v is not None
+                         }
+                         openai_compatible_messages.append(new_msg)
              else:
-                 content_list = message.dict()["content"]
-                 for content in content_list:
-                     if content["type"] == "text":
-                         openai_compatible_messages.append(
-                             {"role": message.role, "content": content["text"]}
-                         )
+                 msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                 openai_compatible_messages.append(msg_dict)
          if (
              openai_compatible_messages
              and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1099,8 +1100,9 @@

          sampling_params = {
              "temperature": request.temperature,
-             "max_new_tokens": request.max_tokens,
+             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
              "min_new_tokens": request.min_tokens,
+             "thinking_budget": request.thinking_budget,
              "stop": stop,
              "stop_token_ids": request.stop_token_ids,
              "top_p": request.top_p,
@@ -1258,31 +1260,16 @@
          tool_calls = None
          text = ret_item["text"]

-         enable_thinking = True
          if isinstance(request, list):
              tool_choice = request[idx].tool_choice
              tools = request[idx].tools
              separate_reasoning = request[idx].separate_reasoning
-
-             if (
-                 request[idx].chat_template_kwargs
-                 and request[idx].chat_template_kwargs.get("enable_thinking") is not None
-             ):
-                 enable_thinking = request[idx].chat_template_kwargs.get(
-                     "enable_thinking", True
-                 )
+             enable_thinking = _get_enable_thinking_from_request(request[idx])
          else:
              tool_choice = request.tool_choice
              tools = request.tools
              separate_reasoning = request.separate_reasoning
-
-             if (
-                 request.chat_template_kwargs
-                 and request.chat_template_kwargs.get("enable_thinking") is not None
-             ):
-                 enable_thinking = request.chat_template_kwargs.get(
-                     "enable_thinking", True
-                 )
+             enable_thinking = _get_enable_thinking_from_request(request)

          reasoning_text = None
          if reasoning_parser and separate_reasoning and enable_thinking:
@@ -1308,7 +1295,8 @@
              text, call_info_list = parser.parse_non_stream(text)
              tool_calls = [
                  ToolCall(
-                     id=str(call_info.tool_index),
+                     id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                     index=call_info.tool_index,
                      function=FunctionResponse(
                          name=call_info.name, arguments=call_info.parameters
                      ),
@@ -1424,6 +1412,7 @@ async def v1_chat_completions(
      reasoning_parser_dict = {}

      async def generate_stream_resp():
+         tool_call_first = True
          is_firsts = {}
          stream_buffers = {}
          n_prev_tokens = {}
@@ -1521,9 +1510,12 @@
                  delta = text[len(stream_buffer) :]
                  new_stream_buffer = stream_buffer + delta

+                 enable_thinking = _get_enable_thinking_from_request(request)
+
                  if (
                      tokenizer_manager.server_args.reasoning_parser
                      and request.separate_reasoning
+                     and enable_thinking
                  ):
                      if index not in reasoning_parser_dict:
                          reasoning_parser_dict[index] = ReasoningParser(
@@ -1587,7 +1579,6 @@
                  # 2) if we found calls, we output them as separate chunk(s)
                  for call_item in calls:
                      # transform call_item -> FunctionResponse + ToolCall
-
                      if finish_reason_type == "stop":
                          latest_delta_len = 0
                          if isinstance(call_item.parameters, str):
@@ -1610,14 +1601,19 @@
                          call_item.parameters = remaining_call

                      finish_reason_type = "tool_calls"
-
                      tool_call = ToolCall(
-                         id=str(call_item.tool_index),
+                         id=(
+                             f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                             if tool_call_first
+                             else None
+                         ),
+                         index=call_item.tool_index,
                          function=FunctionResponse(
                              name=call_item.name,
                              arguments=call_item.parameters,
                          ),
                      )
+                     tool_call_first = False
                      choice_data = ChatCompletionResponseStreamChoice(
                          index=index,
                          delta=DeltaMessage(tool_calls=[tool_call]),
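
Both the non-streaming and streaming paths now emit OpenAI-style tool call IDs instead of reusing the tool index (the index moves to the new index field). The ID expression can be reproduced standalone:

    import base64
    import uuid

    # "call_" + URL-safe base64 of a random UUID, with the trailing "=" padding
    # stripped, i.e. 22 characters after the prefix.
    call_id = f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
    print(call_id)  # e.g. call_q0mPz4bBQ-qkGJ0uqxGJkA

In streaming mode, only the first tool-call chunk in the stream carries a generated id (guarded by tool_call_first); later chunks send id=None and are correlated by index.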
sglang/srt/openai_api/protocol.py
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
      top_k: int = -1
      min_p: float = 0.0
      min_tokens: int = 0
+     thinking_budget: Optional[int] = None
      json_schema: Optional[str] = None
      regex: Optional[str] = None
      ebnf: Optional[str] = None
@@ -250,9 +251,29 @@ ChatCompletionMessageContentPart = Union[
  ]


+ class FunctionResponse(BaseModel):
+     """Function response."""
+
+     name: Optional[str] = None
+     arguments: Optional[str] = None
+
+
+ class ToolCall(BaseModel):
+     """Tool call response."""
+
+     id: Optional[str] = None
+     index: Optional[int] = None
+     type: Literal["function"] = "function"
+     function: FunctionResponse
+
+
  class ChatCompletionMessageGenericParam(BaseModel):
      role: Literal["system", "assistant", "tool"]
      content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+     tool_call_id: Optional[str] = None
+     name: Optional[str] = None
+     reasoning_content: Optional[str] = None
+     tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


  class ChatCompletionMessageUserParam(BaseModel):
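
With tool_call_id, name, reasoning_content, and tool_calls added to the generic message schema, a complete tool-call round trip can be sent back through /v1/chat/completions. An illustrative message list (all values made up):

    messages = [
        {"role": "user", "content": "What's the weather in Paris?"},
        {
            "role": "assistant",
            "content": "",
            "tool_calls": [
                {
                    "id": "call_abc123",
                    "type": "function",
                    "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
                }
            ],
        },
        {"role": "tool", "tool_call_id": "call_abc123", "name": "get_weather", "content": "18°C and sunny"},
    ]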
@@ -320,7 +341,23 @@ class ChatCompletionRequest(BaseModel):
      logit_bias: Optional[Dict[str, float]] = None
      logprobs: bool = False
      top_logprobs: Optional[int] = None
-     max_tokens: Optional[int] = None
+     max_tokens: Optional[int] = Field(
+         default=None,
+         deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
+         description="The maximum number of tokens that can be generated in the chat completion. ",
+     )
+     max_completion_tokens: Optional[int] = Field(
+         default=None,
+         description="The maximum number of completion tokens for a chat completion request, "
+         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
+     )
+     thinking_budget: Optional[int] = Field(
+         default=None,
+         description="The maximum number of reasoning tokens that can be generated for a request. "
+         "This setting of does not affect the thinking process of models. "
+         "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+         "the reasoning content will be truncated and the final response content will be generated immediately.",
+     )
      n: int = 1
      presence_penalty: float = 0.0
      response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
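
Combined with the adapter change that falls back to max_completion_tokens when max_tokens is unset, the new request fields can be exercised as follows (a sketch; assumes a recent openai client, and the URL and model name are placeholders):

    import openai

    client = openai.Client(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
    resp = client.chat.completions.create(
        model="default",
        messages=[{"role": "user", "content": "Prove that sqrt(2) is irrational."}],
        max_completion_tokens=512,            # preferred over the deprecated max_tokens
        extra_body={"thinking_budget": 128},  # sglang-specific: cap reasoning tokens
    )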
@@ -369,21 +406,6 @@ class ChatCompletionRequest(BaseModel):
      bootstrap_room: Optional[int] = None


- class FunctionResponse(BaseModel):
-     """Function response."""
-
-     name: Optional[str] = None
-     arguments: Optional[str] = None
-
-
- class ToolCall(BaseModel):
-     """Tool call response."""
-
-     id: str
-     type: Literal["function"] = "function"
-     function: FunctionResponse
-
-
  class ChatMessage(BaseModel):
      role: Optional[str] = None
      content: Optional[str] = None
sglang/srt/reasoning_parser.py
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
          One-time parsing: Detects and parses reasoning sections in the provided text.
          Returns both reasoning content and normal text separately.
          """
-         text = text.replace(self.think_start_token, "").strip()
+         text = text.replace(self.think_start_token, "")
          if self.think_end_token not in text:
              # Assume reasoning was truncated before `</think>` token
              return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@
              normal_text = current_text[end_idx + len(self.think_end_token) :]

              return StreamingParseResult(
-                 normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+                 normal_text=normal_text, reasoning_text=reasoning_text
              )

          # Continue with reasoning content
sglang/srt/sampling/sampling_batch_info.py
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
      # Whether any request needs min_p sampling
      need_min_p_sampling: bool

+     # Use thinking_budget to truncate thinking
+     num_thinking_tokens: Optional[torch.Tensor] = None
+     think_end_ids: Optional[torch.Tensor] = None
+     thinking_budgets: Optional[torch.Tensor] = None
+
      # Masking tensors for grammar-guided structured outputs
-     vocab_size: int
+     vocab_size: int = 0
      grammars: Optional[List] = None
      vocab_mask: Optional[torch.Tensor] = None
      apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
@@ -76,7 +81,22 @@
          min_ps = torch.tensor(
              [r.sampling_params.min_p for r in reqs], dtype=torch.float
          ).to(device, non_blocking=True)
-
+         if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
+             think_end_ids = torch.tensor(
+                 [getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
+                 dtype=torch.int64,
+             ).to(device, non_blocking=True)
+             num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
+                 device, non_blocking=True
+             )
+             thinking_budgets = torch.tensor(
+                 [r.sampling_params.thinking_budget or -1 for r in reqs],
+                 dtype=torch.int64,
+             ).to(device, non_blocking=True)
+         else:
+             think_end_ids = None
+             num_thinking_tokens = None
+             thinking_budgets = None
          # Check if any request has custom logit processor
          has_custom_logit_processor = (
              batch.enable_custom_logit_processor  # check the flag first.
@@ -132,6 +152,9 @@
              top_ps=top_ps,
              top_ks=top_ks,
              min_ps=min_ps,
+             think_end_ids=think_end_ids,
+             num_thinking_tokens=num_thinking_tokens,
+             thinking_budgets=thinking_budgets,
              is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
              need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
              vocab_size=vocab_size,
@@ -146,6 +169,35 @@
      def __len__(self):
          return len(self.temperatures)

+     def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
+         has_budget = self.thinking_budgets > 0
+         if not has_budget.any():
+             return
+         torch.where(
+             has_budget,
+             self.num_thinking_tokens + 1,
+             self.num_thinking_tokens,
+             out=self.num_thinking_tokens,
+         )
+         should_stop = has_budget & (
+             self.num_thinking_tokens - 1 > self.thinking_budgets
+         )
+         next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
+         batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
+         if len(batch_indices) > 0:
+             end_token_indices = self.think_end_ids[batch_indices]
+             next_token_logits[batch_indices, end_token_indices] = 0.0
+
+     def update_thinking_budgets(self, next_token_ids: torch.Tensor):
+         if not torch.any(self.thinking_budgets > 0):
+             return
+         torch.where(
+             next_token_ids == self.think_end_ids,
+             torch.tensor(-1, device=self.thinking_budgets.device),
+             self.thinking_budgets,
+             out=self.thinking_budgets,
+         )
+
      def update_regex_vocab_mask(self):
          if not self.grammars:
              self.vocab_mask = None
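
Judging from the signatures, apply_thinking_budgets is meant to run on the logits before sampling and update_thinking_budgets after the sampled token ids are known. A standalone toy illustration of the masking step for a [batch, vocab] logits tensor (not the scheduler integration itself):

    import torch

    logits = torch.zeros(2, 8)                 # two requests, vocab size 8
    think_end_ids = torch.tensor([5, 5])       # id of the </think> token per request
    should_stop = torch.tensor([True, False])  # request 0 has exhausted its budget

    # Mask every token for the stopped request, then re-enable only </think>,
    # so the next sampled token is guaranteed to close the reasoning section.
    logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
    rows = torch.nonzero(should_stop, as_tuple=True)[0]
    logits[rows, think_end_ids[rows]] = 0.0
    print(logits.argmax(dim=-1))  # tensor([5, 0]) -> request 0 is forced to emit </think>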
sglang/srt/sampling/sampling_params.py
@@ -30,6 +30,7 @@ class SamplingParams:
      def __init__(
          self,
          max_new_tokens: int = 128,
+         thinking_budget: Optional[int] = None,
          stop: Optional[Union[str, List[str]]] = None,
          stop_token_ids: Optional[List[int]] = None,
          temperature: float = 1.0,
@@ -57,6 +58,7 @@
              self.stop_token_ids = set(stop_token_ids)
          else:
              self.stop_token_ids = None
+         self.thinking_budget = thinking_budget
          self.temperature = temperature
          self.top_p = top_p
          self.top_k = top_k
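
At the engine level the budget is just another sampling parameter; a minimal construction sketch:

    from sglang.srt.sampling.sampling_params import SamplingParams

    # Allow up to 1024 new tokens overall, of which at most 256 may be reasoning tokens.
    params = SamplingParams(max_new_tokens=1024, thinking_budget=256)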