sglang 0.4.6.post1__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -11
- sglang/bench_serving.py +149 -1
- sglang/check_env.py +3 -3
- sglang/lang/chat_template.py +44 -0
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/deepseekvl2.py +3 -0
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/internvl.py +696 -0
- sglang/srt/configs/janus_pro.py +3 -0
- sglang/srt/configs/kimi_vl.py +38 -0
- sglang/srt/configs/kimi_vl_moonvit.py +32 -0
- sglang/srt/configs/model_config.py +32 -0
- sglang/srt/constrained/xgrammar_backend.py +11 -19
- sglang/srt/conversation.py +151 -3
- sglang/srt/disaggregation/decode.py +4 -1
- sglang/srt/disaggregation/mini_lb.py +74 -23
- sglang/srt/disaggregation/mooncake/conn.py +9 -18
- sglang/srt/disaggregation/nixl/conn.py +241 -71
- sglang/srt/disaggregation/utils.py +44 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
- sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
- sglang/srt/distributed/device_communicators/pynccl.py +2 -1
- sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
- sglang/srt/distributed/parallel_state.py +22 -1
- sglang/srt/entrypoints/engine.py +58 -24
- sglang/srt/entrypoints/http_server.py +28 -1
- sglang/srt/entrypoints/verl_engine.py +3 -2
- sglang/srt/function_call_parser.py +97 -0
- sglang/srt/hf_transformers_utils.py +22 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +146 -50
- sglang/srt/layers/attention/flashinfer_backend.py +129 -94
- sglang/srt/layers/attention/flashinfer_mla_backend.py +88 -30
- sglang/srt/layers/attention/flashmla_backend.py +3 -0
- sglang/srt/layers/attention/merge_state.py +46 -0
- sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
- sglang/srt/layers/attention/vision.py +290 -163
- sglang/srt/layers/dp_attention.py +5 -2
- sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
- sglang/srt/layers/moe/ep_moe/layer.py +120 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +98 -57
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=768,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +10 -5
- sglang/srt/layers/quantization/__init__.py +2 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
- sglang/srt/layers/quantization/deep_gemm.py +6 -1
- sglang/srt/layers/quantization/fp8.py +108 -95
- sglang/srt/layers/quantization/fp8_kernel.py +79 -60
- sglang/srt/layers/quantization/fp8_utils.py +71 -23
- sglang/srt/layers/quantization/kv_cache.py +3 -10
- sglang/srt/layers/quantization/utils.py +0 -5
- sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
- sglang/srt/layers/utils.py +35 -0
- sglang/srt/lora/layers.py +35 -9
- sglang/srt/lora/lora_manager.py +81 -35
- sglang/srt/managers/cache_controller.py +115 -119
- sglang/srt/managers/data_parallel_controller.py +52 -34
- sglang/srt/managers/io_struct.py +10 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
- sglang/srt/managers/multimodal_processors/internvl.py +232 -0
- sglang/srt/managers/multimodal_processors/kimi_vl.py +73 -0
- sglang/srt/managers/schedule_batch.py +44 -16
- sglang/srt/managers/schedule_policy.py +11 -5
- sglang/srt/managers/scheduler.py +291 -72
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +24 -13
- sglang/srt/managers/tp_worker.py +60 -28
- sglang/srt/managers/tp_worker_overlap_thread.py +9 -3
- sglang/srt/mem_cache/chunk_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +70 -36
- sglang/srt/model_executor/cuda_graph_runner.py +82 -19
- sglang/srt/model_executor/forward_batch_info.py +31 -1
- sglang/srt/model_executor/model_runner.py +159 -90
- sglang/srt/model_loader/loader.py +18 -11
- sglang/srt/models/clip.py +4 -4
- sglang/srt/models/deepseek_janus_pro.py +1 -1
- sglang/srt/models/deepseek_nextn.py +2 -277
- sglang/srt/models/deepseek_v2.py +132 -37
- sglang/srt/models/gemma3_mm.py +1 -1
- sglang/srt/models/internlm2.py +3 -0
- sglang/srt/models/internvl.py +670 -0
- sglang/srt/models/kimi_vl.py +308 -0
- sglang/srt/models/kimi_vl_moonvit.py +639 -0
- sglang/srt/models/llama.py +93 -31
- sglang/srt/models/llama4.py +54 -7
- sglang/srt/models/llama_eagle.py +4 -1
- sglang/srt/models/llama_eagle3.py +4 -1
- sglang/srt/models/minicpmv.py +1 -1
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/phi3_small.py +16 -2
- sglang/srt/models/qwen2_5_vl.py +8 -4
- sglang/srt/models/qwen2_moe.py +8 -3
- sglang/srt/models/qwen2_vl.py +4 -16
- sglang/srt/models/qwen3_moe.py +8 -3
- sglang/srt/models/xiaomi_mimo.py +171 -0
- sglang/srt/openai_api/adapter.py +58 -62
- sglang/srt/openai_api/protocol.py +38 -16
- sglang/srt/reasoning_parser.py +2 -2
- sglang/srt/sampling/sampling_batch_info.py +54 -2
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +93 -24
- sglang/srt/speculative/eagle_worker.py +3 -2
- sglang/srt/utils.py +123 -10
- sglang/test/runners.py +4 -0
- sglang/test/test_block_fp8.py +2 -2
- sglang/test/test_deepep_utils.py +219 -0
- sglang/test/test_utils.py +32 -1
- sglang/version.py +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +18 -9
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +119 -99
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post1.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0
sglang/srt/models/xiaomi_mimo.py
ADDED
@@ -0,0 +1,171 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+MiMoConfig = None
+
+
+class MiMoModel(Qwen2Model):
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen2DecoderLayer,
+        )
+
+
+class MiMoForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MiMoModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if (
+                "rotary_emb.inv_freq" in name
+                or "projector" in name
+                or "mtp_layers" in name
+            ):
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = MiMoForCausalLM
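The new MiMo implementation above reuses the Qwen2 decoder stack, adds only a thin wrapper plus weight-loading rules (it skips `mtp_layers` and projector weights), and registers itself through `EntryClass` like every other model in `sglang/srt/models`. A minimal offline-engine sketch of how such a checkpoint could be loaded; the model path, prompt, and sampling values are illustrative assumptions, not part of the diff:

```python
# Sketch only: loading a MiMo-style checkpoint with sglang's offline Engine API.
# "XiaomiMiMo/MiMo-7B-RL" is an assumed checkpoint path for illustration.
import sglang as sgl

if __name__ == "__main__":
    engine = sgl.Engine(model_path="XiaomiMiMo/MiMo-7B-RL")
    out = engine.generate(
        prompt="Summarize what a KV cache is in one sentence.",
        sampling_params={"temperature": 0.0, "max_new_tokens": 64},
    )
    print(out["text"])  # single prompt -> single dict with the generated text
    engine.shutdown()
```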
sglang/srt/openai_api/adapter.py
CHANGED
@@ -14,6 +14,7 @@
 """Conversion between OpenAI APIs and native SRT APIs"""

 import asyncio
+import base64
 import json
 import logging
 import os
@@ -36,6 +37,7 @@ from sglang.srt.conversation import (
     chat_template_exists,
     generate_chat_conv,
     generate_embedding_convs,
+    get_conv_template_by_model_path,
     register_conv_template,
 )
 from sglang.srt.function_call_parser import FunctionCallParser
@@ -163,10 +165,14 @@ def load_chat_template_for_openai_api(tokenizer_manager, chat_template_arg, mode
     else:
         chat_template_name = chat_template_arg

-
-
-
-
+
+def guess_chat_template_name_from_model_path(model_path):
+    global chat_template_name
+    chat_template_name = get_conv_template_by_model_path(model_path)
+    if chat_template_name is not None:
+        logger.info(
+            f"Infer the chat template name from the model path and obtain the result: {chat_template_name}."
+        )


 async def v1_files_create(
@@ -523,6 +529,7 @@ def v1_generate_request(
            "temperature": request.temperature,
            "max_new_tokens": request.max_tokens,
            "min_new_tokens": request.min_tokens,
+           "thinking_budget": request.thinking_budget,
            "stop": request.stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
@@ -894,6 +901,24 @@ async def v1_completions(tokenizer_manager, raw_request: Request):
     return response


+def _get_enable_thinking_from_request(request_obj):
+    """Extracts the 'enable_thinking' flag from request chat_template_kwargs.
+
+    Args:
+        request_obj: The request object (or an item from a list of requests).
+
+    Returns:
+        The boolean value of 'enable_thinking' if found and not True, otherwise True.
+    """
+    if (
+        hasattr(request_obj, "chat_template_kwargs")
+        and request_obj.chat_template_kwargs
+        and request_obj.chat_template_kwargs.get("enable_thinking") is not None
+    ):
+        return request_obj.chat_template_kwargs.get("enable_thinking")
+    return True
+
+
 def v1_chat_generate_request(
     all_requests: List[ChatCompletionRequest],
     tokenizer_manager,
@@ -943,47 +968,23 @@ def v1_chat_generate_request(

         if chat_template_name is None:
             openai_compatible_messages = []
-            if (
-                tools
-                and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
-            ):
-                # add function call prompt to deepseekv3
-                openai_compatible_messages.append(
-                    {
-                        "role": "system",
-                        "content": """You are a helpful Assistant.
-## Tools
-### Function
-You have the following functions available:
-"""
-                        + "".join(
-                            [
-                                f"""
-- `{tool['name']}`:
-```json
-{json.dumps(tool)}
-```
-"""
-                                for tool in tools
-                            ]
-                        ),
-                    }
-                )

             for message in request.messages:
                 if message.content is None:
                     message.content = ""
-
-
-
-
+                msg_dict = message.dict()
+                if isinstance(msg_dict.get("content"), list):
+                    for chunk in msg_dict["content"]:
+                        if isinstance(chunk, dict) and chunk.get("type") == "text":
+                            new_msg = msg_dict.copy()
+                            new_msg["content"] = chunk["text"]
+                            new_msg = {
+                                k: v for k, v in new_msg.items() if v is not None
+                            }
+                            openai_compatible_messages.append(new_msg)
                 else:
-
-
-                    if content["type"] == "text":
-                        openai_compatible_messages.append(
-                            {"role": message.role, "content": content["text"]}
-                        )
+                    msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                    openai_compatible_messages.append(msg_dict)
             if (
                 openai_compatible_messages
                 and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1099,8 +1100,9 @@

         sampling_params = {
             "temperature": request.temperature,
-            "max_new_tokens": request.max_tokens,
+            "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": request.thinking_budget,
            "stop": stop,
            "stop_token_ids": request.stop_token_ids,
            "top_p": request.top_p,
@@ -1258,31 +1260,16 @@ def v1_chat_generate_response(
         tool_calls = None
         text = ret_item["text"]

-        enable_thinking = True
         if isinstance(request, list):
             tool_choice = request[idx].tool_choice
             tools = request[idx].tools
             separate_reasoning = request[idx].separate_reasoning
-
-            if (
-                request[idx].chat_template_kwargs
-                and request[idx].chat_template_kwargs.get("enable_thinking") is not None
-            ):
-                enable_thinking = request[idx].chat_template_kwargs.get(
-                    "enable_thinking", True
-                )
+            enable_thinking = _get_enable_thinking_from_request(request[idx])
         else:
             tool_choice = request.tool_choice
             tools = request.tools
             separate_reasoning = request.separate_reasoning
-
-            if (
-                request.chat_template_kwargs
-                and request.chat_template_kwargs.get("enable_thinking") is not None
-            ):
-                enable_thinking = request.chat_template_kwargs.get(
-                    "enable_thinking", True
-                )
+            enable_thinking = _get_enable_thinking_from_request(request)

         reasoning_text = None
         if reasoning_parser and separate_reasoning and enable_thinking:
@@ -1308,7 +1295,8 @@ def v1_chat_generate_response(
             text, call_info_list = parser.parse_non_stream(text)
             tool_calls = [
                 ToolCall(
-                    id=
+                    id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                    index=call_info.tool_index,
                     function=FunctionResponse(
                         name=call_info.name, arguments=call_info.parameters
                     ),
@@ -1424,6 +1412,7 @@ async def v1_chat_completions(
     reasoning_parser_dict = {}

     async def generate_stream_resp():
+        tool_call_first = True
         is_firsts = {}
         stream_buffers = {}
         n_prev_tokens = {}
@@ -1521,9 +1510,12 @@
                    delta = text[len(stream_buffer) :]
                    new_stream_buffer = stream_buffer + delta

+                   enable_thinking = _get_enable_thinking_from_request(request)
+
                    if (
                        tokenizer_manager.server_args.reasoning_parser
                        and request.separate_reasoning
+                       and enable_thinking
                    ):
                        if index not in reasoning_parser_dict:
                            reasoning_parser_dict[index] = ReasoningParser(
@@ -1587,7 +1579,6 @@
                    # 2) if we found calls, we output them as separate chunk(s)
                    for call_item in calls:
                        # transform call_item -> FunctionResponse + ToolCall
-
                        if finish_reason_type == "stop":
                            latest_delta_len = 0
                            if isinstance(call_item.parameters, str):
@@ -1610,14 +1601,19 @@
                            call_item.parameters = remaining_call

                        finish_reason_type = "tool_calls"
-
                        tool_call = ToolCall(
-                           id=
+                           id=(
+                               f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                               if tool_call_first
+                               else None
+                           ),
+                           index=call_item.tool_index,
                            function=FunctionResponse(
                                name=call_item.name,
                                arguments=call_item.parameters,
                            ),
                        )
+                       tool_call_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(tool_calls=[tool_call]),
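The adapter changes above route the reasoning controls through one helper (`_get_enable_thinking_from_request`), forward `thinking_budget` into the sampling params, and assign `call_...` IDs to tool calls. A hedged client-side sketch of how these request fields could be exercised against the OpenAI-compatible endpoint; the URL, model name, prompt, and budget values are placeholders:

```python
# Sketch: exercising the new reasoning-related request fields via extra_body.
# Assumes a locally running sglang server with a reasoning-capable model.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")

resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "How many primes are below 20?"}],
    extra_body={
        "chat_template_kwargs": {"enable_thinking": True},  # read by _get_enable_thinking_from_request
        "thinking_budget": 128,        # cap on reasoning tokens (new field)
        "max_completion_tokens": 256,  # successor to the deprecated max_tokens
        "separate_reasoning": True,    # return reasoning separately when a parser is configured
    },
)
print(resp.choices[0].message.content)
```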
sglang/srt/openai_api/protocol.py
CHANGED
@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -250,9 +251,29 @@ ChatCompletionMessageContentPart = Union[
 ]


+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])


 class ChatCompletionMessageUserParam(BaseModel):
@@ -320,7 +341,23 @@ class ChatCompletionRequest(BaseModel):
     logit_bias: Optional[Dict[str, float]] = None
     logprobs: bool = False
     top_logprobs: Optional[int] = None
-    max_tokens: Optional[int] =
+    max_tokens: Optional[int] = Field(
+        default=None,
+        deprecated="max_tokens is deprecated in favor of the max_completion_tokens field",
+        description="The maximum number of tokens that can be generated in the chat completion. ",
+    )
+    max_completion_tokens: Optional[int] = Field(
+        default=None,
+        description="The maximum number of completion tokens for a chat completion request, "
+        "including visible output tokens and reasoning tokens. Input tokens are not included. ",
+    )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description="The maximum number of reasoning tokens that can be generated for a request. "
+        "This setting of does not affect the thinking process of models. "
+        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+        "the reasoning content will be truncated and the final response content will be generated immediately.",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
@@ -369,21 +406,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_room: Optional[int] = None


-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: str
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None
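The schema now carries both the deprecated `max_tokens` and the newer `max_completion_tokens`, and the adapter resolves them with `request.max_tokens or request.max_completion_tokens`. A standalone toy sketch of that fallback; the model below only mirrors a few field names and is not the real `ChatCompletionRequest`:

```python
# Toy illustration of the max_tokens / max_completion_tokens fallback used by
# v1_chat_generate_request(). Not the actual protocol model.
from typing import Optional

from pydantic import BaseModel, Field


class DemoChatRequest(BaseModel):
    max_tokens: Optional[int] = Field(default=None)  # deprecated alias
    max_completion_tokens: Optional[int] = Field(default=None)
    thinking_budget: Optional[int] = Field(default=None)


old_style = DemoChatRequest(max_tokens=256)
new_style = DemoChatRequest(max_completion_tokens=256, thinking_budget=64)

for req in (old_style, new_style):
    # Same resolution rule as the adapter: prefer max_tokens when it is set.
    max_new_tokens = req.max_tokens or req.max_completion_tokens
    print(max_new_tokens, req.thinking_budget)
```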
sglang/srt/reasoning_parser.py
CHANGED
@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "")
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
             normal_text = current_text[end_idx + len(self.think_end_token) :]

             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text
+                normal_text=normal_text, reasoning_text=reasoning_text
             )

         # Continue with reasoning content
|
@@ -30,8 +30,13 @@ class SamplingBatchInfo:
|
|
30
30
|
# Whether any request needs min_p sampling
|
31
31
|
need_min_p_sampling: bool
|
32
32
|
|
33
|
+
# Use thinking_budget to truncate thinking
|
34
|
+
num_thinking_tokens: Optional[torch.Tensor] = None
|
35
|
+
think_end_ids: Optional[torch.Tensor] = None
|
36
|
+
thinking_budgets: Optional[torch.Tensor] = None
|
37
|
+
|
33
38
|
# Masking tensors for grammar-guided structured outputs
|
34
|
-
vocab_size: int
|
39
|
+
vocab_size: int = 0
|
35
40
|
grammars: Optional[List] = None
|
36
41
|
vocab_mask: Optional[torch.Tensor] = None
|
37
42
|
apply_mask_func: Optional[Callable[[torch.Tensor, torch.Tensor], None]] = None
|
@@ -76,7 +81,22 @@ class SamplingBatchInfo:
|
|
76
81
|
min_ps = torch.tensor(
|
77
82
|
[r.sampling_params.min_p for r in reqs], dtype=torch.float
|
78
83
|
).to(device, non_blocking=True)
|
79
|
-
|
84
|
+
if any(hasattr(r.tokenizer, "think_end_id") for r in reqs):
|
85
|
+
think_end_ids = torch.tensor(
|
86
|
+
[getattr(r.tokenizer, "think_end_id", -1) for r in reqs],
|
87
|
+
dtype=torch.int64,
|
88
|
+
).to(device, non_blocking=True)
|
89
|
+
num_thinking_tokens = torch.tensor([0 for _ in reqs], dtype=torch.int64).to(
|
90
|
+
device, non_blocking=True
|
91
|
+
)
|
92
|
+
thinking_budgets = torch.tensor(
|
93
|
+
[r.sampling_params.thinking_budget or -1 for r in reqs],
|
94
|
+
dtype=torch.int64,
|
95
|
+
).to(device, non_blocking=True)
|
96
|
+
else:
|
97
|
+
think_end_ids = None
|
98
|
+
num_thinking_tokens = None
|
99
|
+
thinking_budgets = None
|
80
100
|
# Check if any request has custom logit processor
|
81
101
|
has_custom_logit_processor = (
|
82
102
|
batch.enable_custom_logit_processor # check the flag first.
|
@@ -132,6 +152,9 @@ class SamplingBatchInfo:
|
|
132
152
|
top_ps=top_ps,
|
133
153
|
top_ks=top_ks,
|
134
154
|
min_ps=min_ps,
|
155
|
+
think_end_ids=think_end_ids,
|
156
|
+
num_thinking_tokens=num_thinking_tokens,
|
157
|
+
thinking_budgets=thinking_budgets,
|
135
158
|
is_all_greedy=all(r.sampling_params.top_k <= 1 for r in reqs),
|
136
159
|
need_min_p_sampling=any(r.sampling_params.min_p > 0 for r in reqs),
|
137
160
|
vocab_size=vocab_size,
|
@@ -146,6 +169,35 @@ class SamplingBatchInfo:
|
|
146
169
|
def __len__(self):
|
147
170
|
return len(self.temperatures)
|
148
171
|
|
172
|
+
def apply_thinking_budgets(self, next_token_logits: torch.Tensor):
|
173
|
+
has_budget = self.thinking_budgets > 0
|
174
|
+
if not has_budget.any():
|
175
|
+
return
|
176
|
+
torch.where(
|
177
|
+
has_budget,
|
178
|
+
self.num_thinking_tokens + 1,
|
179
|
+
self.num_thinking_tokens,
|
180
|
+
out=self.num_thinking_tokens,
|
181
|
+
)
|
182
|
+
should_stop = has_budget & (
|
183
|
+
self.num_thinking_tokens - 1 > self.thinking_budgets
|
184
|
+
)
|
185
|
+
next_token_logits.masked_fill_(should_stop.unsqueeze(0), float("-inf"))
|
186
|
+
batch_indices = torch.nonzero(should_stop, as_tuple=True)[0]
|
187
|
+
if len(batch_indices) > 0:
|
188
|
+
end_token_indices = self.think_end_ids[batch_indices]
|
189
|
+
next_token_logits[batch_indices, end_token_indices] = 0.0
|
190
|
+
|
191
|
+
def update_thinking_budgets(self, next_token_ids: torch.Tensor):
|
192
|
+
if not torch.any(self.thinking_budgets > 0):
|
193
|
+
return
|
194
|
+
torch.where(
|
195
|
+
next_token_ids == self.think_end_ids,
|
196
|
+
torch.tensor(-1, device=self.thinking_budgets.device),
|
197
|
+
self.thinking_budgets,
|
198
|
+
out=self.thinking_budgets,
|
199
|
+
)
|
200
|
+
|
149
201
|
def update_regex_vocab_mask(self):
|
150
202
|
if not self.grammars:
|
151
203
|
self.vocab_mask = None
|
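Budget enforcement here is a logits-level intervention: `apply_thinking_budgets` counts thinking tokens per request and, once a budget is exceeded, masks every logit to `-inf` except the tokenizer's `think_end_id`, forcing the model to close its reasoning block on the next step. A self-contained toy sketch of that masking with made-up token ids and a `[batch, vocab]` logits layout (not the scheduler integration):

```python
# Toy reproduction of the thinking-budget masking step on a 2-request batch.
import torch

vocab_size = 8
think_end_ids = torch.tensor([5, 5])        # assumed id of the </think> token per request
thinking_budgets = torch.tensor([2, -1])    # request 0 has a budget, request 1 does not
num_thinking_tokens = torch.tensor([4, 4])  # thinking tokens emitted so far

logits = torch.randn(2, vocab_size)

has_budget = thinking_budgets > 0
should_stop = has_budget & (num_thinking_tokens - 1 > thinking_budgets)

# Mask every token for over-budget requests, then re-enable only </think>.
logits.masked_fill_(should_stop.unsqueeze(1), float("-inf"))
rows = torch.nonzero(should_stop, as_tuple=True)[0]
logits[rows, think_end_ids[rows]] = 0.0

print(logits.argmax(dim=-1))  # request 0 is now forced to token 5, i.e. </think>
```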
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -30,6 +30,7 @@ class SamplingParams:
     def __init__(
         self,
         max_new_tokens: int = 128,
+        thinking_budget: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
@@ -57,6 +58,7 @@ class SamplingParams:
             self.stop_token_ids = set(stop_token_ids)
         else:
             self.stop_token_ids = None
+        self.thinking_budget = thinking_budget
         self.temperature = temperature
         self.top_p = top_p
         self.top_k = top_k