sglang 0.4.6.post2__py3-none-any.whl → 0.4.6.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (90)
  1. sglang/bench_one_batch.py +1 -11
  2. sglang/bench_serving.py +149 -1
  3. sglang/lang/chat_template.py +44 -0
  4. sglang/srt/configs/deepseekvl2.py +3 -0
  5. sglang/srt/configs/device_config.py +1 -1
  6. sglang/srt/configs/internvl.py +696 -0
  7. sglang/srt/configs/janus_pro.py +3 -0
  8. sglang/srt/configs/model_config.py +17 -0
  9. sglang/srt/constrained/xgrammar_backend.py +11 -19
  10. sglang/srt/conversation.py +30 -3
  11. sglang/srt/disaggregation/decode.py +4 -1
  12. sglang/srt/disaggregation/mini_lb.py +74 -23
  13. sglang/srt/disaggregation/mooncake/conn.py +9 -18
  14. sglang/srt/disaggregation/nixl/conn.py +241 -71
  15. sglang/srt/disaggregation/utils.py +44 -1
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -8
  17. sglang/srt/distributed/device_communicators/npu_communicator.py +39 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +2 -1
  19. sglang/srt/distributed/device_communicators/shm_broadcast.py +2 -1
  20. sglang/srt/distributed/parallel_state.py +22 -1
  21. sglang/srt/entrypoints/engine.py +14 -2
  22. sglang/srt/entrypoints/http_server.py +28 -1
  23. sglang/srt/entrypoints/verl_engine.py +3 -2
  24. sglang/srt/hf_transformers_utils.py +20 -1
  25. sglang/srt/layers/attention/flashattention_backend.py +146 -50
  26. sglang/srt/layers/attention/flashinfer_backend.py +23 -13
  27. sglang/srt/layers/attention/flashinfer_mla_backend.py +62 -15
  28. sglang/srt/layers/attention/merge_state.py +46 -0
  29. sglang/srt/layers/attention/triton_ops/merge_state.py +96 -0
  30. sglang/srt/layers/attention/vision.py +290 -163
  31. sglang/srt/layers/moe/ep_moe/kernels.py +342 -7
  32. sglang/srt/layers/moe/ep_moe/layer.py +120 -1
  33. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +97 -54
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +4 -1
  37. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +2 -4
  38. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +2 -1
  39. sglang/srt/layers/quantization/deep_gemm.py +5 -0
  40. sglang/srt/layers/quantization/fp8.py +108 -95
  41. sglang/srt/layers/quantization/fp8_kernel.py +79 -60
  42. sglang/srt/layers/quantization/fp8_utils.py +71 -23
  43. sglang/srt/layers/quantization/kv_cache.py +3 -10
  44. sglang/srt/layers/quantization/utils.py +0 -5
  45. sglang/srt/layers/quantization/w8a8_fp8.py +8 -10
  46. sglang/srt/lora/lora_manager.py +10 -13
  47. sglang/srt/managers/cache_controller.py +115 -119
  48. sglang/srt/managers/io_struct.py +10 -0
  49. sglang/srt/managers/multimodal_processors/base_processor.py +5 -0
  50. sglang/srt/managers/multimodal_processors/internvl.py +232 -0
  51. sglang/srt/managers/schedule_batch.py +19 -1
  52. sglang/srt/managers/schedule_policy.py +11 -5
  53. sglang/srt/managers/scheduler.py +28 -13
  54. sglang/srt/managers/tokenizer_manager.py +24 -13
  55. sglang/srt/managers/tp_worker.py +9 -12
  56. sglang/srt/mem_cache/chunk_cache.py +2 -0
  57. sglang/srt/mem_cache/memory_pool.py +2 -2
  58. sglang/srt/model_executor/model_runner.py +44 -33
  59. sglang/srt/model_loader/loader.py +18 -11
  60. sglang/srt/models/clip.py +4 -4
  61. sglang/srt/models/deepseek_janus_pro.py +1 -1
  62. sglang/srt/models/deepseek_nextn.py +1 -20
  63. sglang/srt/models/deepseek_v2.py +55 -20
  64. sglang/srt/models/gemma3_mm.py +1 -1
  65. sglang/srt/models/internlm2.py +3 -0
  66. sglang/srt/models/internvl.py +670 -0
  67. sglang/srt/models/llama.py +1 -1
  68. sglang/srt/models/llama4.py +53 -7
  69. sglang/srt/models/minicpmv.py +1 -1
  70. sglang/srt/models/mllama.py +1 -1
  71. sglang/srt/models/phi3_small.py +16 -2
  72. sglang/srt/models/qwen2_5_vl.py +8 -4
  73. sglang/srt/models/qwen2_vl.py +4 -4
  74. sglang/srt/models/xiaomi_mimo.py +171 -0
  75. sglang/srt/openai_api/adapter.py +24 -40
  76. sglang/srt/openai_api/protocol.py +28 -16
  77. sglang/srt/reasoning_parser.py +2 -2
  78. sglang/srt/sampling/sampling_batch_info.py +54 -2
  79. sglang/srt/sampling/sampling_params.py +2 -0
  80. sglang/srt/server_args.py +30 -6
  81. sglang/srt/utils.py +35 -1
  82. sglang/test/test_block_fp8.py +2 -2
  83. sglang/test/test_deepep_utils.py +219 -0
  84. sglang/test/test_utils.py +3 -1
  85. sglang/version.py +1 -1
  86. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/METADATA +14 -6
  87. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/RECORD +90 -80
  88. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/WHEEL +1 -1
  89. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/licenses/LICENSE +0 -0
  90. {sglang-0.4.6.post2.dist-info → sglang-0.4.6.post3.dist-info}/top_level.txt +0 -0

sglang/srt/models/llama4.py

@@ -46,7 +46,11 @@ from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_executor.forward_batch_info import (
+    ForwardBatch,
+    ForwardMode,
+    PPProxyTensors,
+)
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaMLP
 from sglang.srt.utils import add_prefix, fast_topk, get_compiler_backend, make_layers
 
@@ -81,6 +85,7 @@ class Llama4MoE(nn.Module):
         super().__init__()
         self.tp_size = get_tensor_model_parallel_world_size()
         self.top_k = config.num_experts_per_tok
+        self.device_module = torch.get_device_module()
 
         intermediate_size_moe = config.intermediate_size
         self.router = ReplicatedLinear(
@@ -113,7 +118,25 @@ class Llama4MoE(nn.Module):
             reduce_results=False,  # We need to do scatter before reduce
         )
 
-    def forward(self, hidden_states):
+    def forward(self, hidden_states, forward_batch: ForwardBatch):
+        shared_out, routed_out = self._forward_core(
+            hidden_states, forward_batch.forward_mode
+        )
+
+        out_aD = routed_out + shared_out
+
+        if self.tp_size > 1:
+            out_aD = tensor_model_parallel_all_reduce(out_aD)
+
+        return out_aD
+
+    def _forward_core(self, hidden_states, forward_mode: ForwardMode):
+        if hidden_states.shape[0] < 4:
+            return self._forward_core_shared_routed_overlap(hidden_states)
+        else:
+            return self._forward_core_normal(hidden_states)
+
+    def _forward_core_normal(self, hidden_states):
         # router_scores: [num_tokens, num_experts]
         router_logits, _ = self.router(hidden_states)
         shared_out = self.shared_expert(hidden_states)
@@ -121,12 +144,35 @@ class Llama4MoE(nn.Module):
             hidden_states=hidden_states,
             router_logits=router_logits,
         )
-        out_aD = routed_out + shared_out
+        return shared_out, routed_out
 
-        if self.tp_size > 1:
-            out_aD = tensor_model_parallel_all_reduce(out_aD)
+    def _forward_core_shared_routed_overlap(self, hidden_states):
+        alt_stream = _get_or_create_alt_stream(self.device_module)
 
-        return out_aD
+        alt_stream.wait_stream(self.device_module.current_stream())
+
+        shared_out = self.shared_expert(hidden_states)
+
+        with self.device_module.stream(alt_stream):
+            # router_scores: [num_tokens, num_experts]
+            router_logits, _ = self.router(hidden_states)
+            routed_out = self.experts(
+                hidden_states=hidden_states,
+                router_logits=router_logits,
+            )
+        self.device_module.current_stream().wait_stream(alt_stream)
+
+        return shared_out, routed_out
+
+
+_alt_stream = None
+
+
+def _get_or_create_alt_stream(device_module):
+    global _alt_stream
+    if _alt_stream is None:
+        _alt_stream = device_module.Stream()
+    return _alt_stream
 
 
 class Llama4Attention(nn.Module):
@@ -380,7 +426,7 @@ class Llama4DecoderLayer(nn.Module):
         )
 
         # Fully Connected
-        hidden_states = self.feed_forward(hidden_states)
+        hidden_states = self.feed_forward(hidden_states, forward_batch)
 
         # TODO(ch-wan): ues reduce-scatter in MLP to avoid this scatter
         # Scatter
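
The new `_forward_core_shared_routed_overlap` path above hides the latency of the shared expert behind the routed experts by issuing them on two device streams and re-synchronizing before the outputs are combined; the diff only takes this branch for very small batches (fewer than 4 tokens), where neither matmul saturates the GPU on its own. A minimal standalone sketch of the same pattern, assuming a CUDA device and plain callables in place of the sglang modules:

import torch

_alt_stream = None

def overlapped_moe(hidden_states, shared_expert, routed_experts):
    # Standalone sketch of the stream-overlap pattern used by Llama4MoE above;
    # `shared_expert` and `routed_experts` are placeholder callables, not sglang modules.
    global _alt_stream
    if _alt_stream is None:
        _alt_stream = torch.cuda.Stream()
    # The side stream must first wait for everything already queued on the default stream.
    _alt_stream.wait_stream(torch.cuda.current_stream())
    shared_out = shared_expert(hidden_states)       # runs on the default stream
    with torch.cuda.stream(_alt_stream):
        routed_out = routed_experts(hidden_states)  # runs concurrently on the side stream
    # Rejoin the streams before the two results are combined.
    torch.cuda.current_stream().wait_stream(_alt_stream)
    return shared_out + routed_out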

sglang/srt/models/minicpmv.py

@@ -197,7 +197,7 @@ class Idefics2EncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=config.attention_dropout,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=True,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),

sglang/srt/models/mllama.py

@@ -203,7 +203,7 @@ class MllamaVisionEncoderLayer(nn.Module):
             use_qkv_parallel=True,
             quant_config=quant_config,
             dropout=0.0,
-            use_context_forward=False,
+            qkv_backend="sdpa",
             softmax_in_single_precision=False,
             flatten_batch=False,
             prefix=add_prefix("self_attn", prefix),

sglang/srt/models/phi3_small.py

@@ -6,7 +6,7 @@ from torch import nn
 from transformers import Phi3Config
 from transformers.configuration_utils import PretrainedConfig
 
-from sglang.srt.distributed import get_tensor_model_parallel_world_size
+from sglang.srt.distributed import get_pp_group, get_tensor_model_parallel_world_size
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
@@ -17,6 +17,7 @@ from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
@@ -294,13 +295,24 @@ class Phi3SmallModel(nn.Module):
         super().__init__()
 
         self.config = config
+
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
             prefix=add_prefix("embed_tokens", prefix),
         )
         self.mup_embedding_multiplier = config.mup_embedding_multiplier
-        self.start_layer, self.end_layer, self.layers = make_layers(
+        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: Phi3SmallDecoderLayer(
                config,
@@ -308,6 +320,8 @@ class Phi3SmallModel(nn.Module):
                 quant_config,
                 prefix=prefix,
             ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
             prefix=add_prefix("layers", prefix),
         )
 

sglang/srt/models/qwen2_5_vl.py

@@ -125,16 +125,20 @@ class Qwen2_5_VisionBlock(nn.Module):
         self.norm1 = Qwen2RMSNorm(dim, eps=1e-6)
         self.norm2 = Qwen2RMSNorm(dim, eps=1e-6)
         if attn_implementation == "sdpa":
-            use_context_forward = False
             softmax_in_single_precision = False
+            qkv_backend = "sdpa"
             flatten_batch = True
         elif attn_implementation == "flash_attention_2":
             softmax_in_single_precision = False
-            use_context_forward = True
+            qkv_backend = "triton_attn"
             flatten_batch = True
         elif attn_implementation == "eager":
             softmax_in_single_precision = True
-            use_context_forward = False
+            qkv_backend = "sdpa"
+            flatten_batch = True
+        elif attn_implementation == "flash_attention_3":
+            softmax_in_single_precision = False
+            qkv_backend = "fa3"
             flatten_batch = True
 
         self.attn = VisionAttention(
@@ -142,7 +146,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=flatten_batch,
             quant_config=quant_config,

sglang/srt/models/qwen2_vl.py

@@ -139,21 +139,21 @@ class Qwen2VisionBlock(nn.Module):
         self.norm2 = norm_layer(dim)
         mlp_hidden_dim = int(dim * mlp_ratio)
         if attn_implementation == "sdpa":
-            use_context_forward = False
+            qkv_backend = "sdpa"
             softmax_in_single_precision = False
         elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
             softmax_in_single_precision = False
-            use_context_forward = True
         elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
             softmax_in_single_precision = True
-            use_context_forward = False
 
         self.attn = VisionAttention(
             embed_dim=dim,
             num_heads=num_heads,
             projection_size=dim,
             use_qkv_parallel=True,
-            use_context_forward=use_context_forward,
+            qkv_backend=qkv_backend,
             softmax_in_single_precision=softmax_in_single_precision,
             flatten_batch=True,
             quant_config=quant_config,
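
Both Qwen vision blocks now translate the Hugging Face `attn_implementation` string into a `qkv_backend` name for `VisionAttention` instead of the old `use_context_forward` flag. Summarized as a plain lookup (illustrative only; the real code keeps the if/elif chains, which also pick `softmax_in_single_precision` and `flatten_batch`):

# Illustrative summary of the attn_implementation -> qkv_backend mapping used above.
QKV_BACKEND_BY_ATTN_IMPL = {
    "sdpa": "sdpa",
    "flash_attention_2": "triton_attn",
    "eager": "sdpa",             # eager additionally runs softmax in fp32
    "flash_attention_3": "fa3",  # only handled in the Qwen2.5-VL block
}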

sglang/srt/models/xiaomi_mimo.py (new file)

@@ -0,0 +1,171 @@
+# Adapted from qwen2.py
+
+from functools import partial
+from typing import Any, Dict, Iterable, Optional, Tuple
+
+import torch
+from torch import nn
+
+from sglang.srt.distributed import (
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+    split_tensor_along_last_dim,
+    tensor_model_parallel_all_gather,
+)
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import QKVParallelLinear, RowParallelLinear
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.vocab_parallel_embedding import ParallelLMHead
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.models.qwen2 import Qwen2DecoderLayer, Qwen2MLP, Qwen2Model
+from sglang.srt.utils import add_prefix
+
+MiMoConfig = None
+
+
+class MiMoModel(Qwen2Model):
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__(
+            config=config,
+            quant_config=quant_config,
+            prefix=prefix,
+            decoder_layer_type=Qwen2DecoderLayer,
+        )
+
+
+class MiMoForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
+    def __init__(
+        self,
+        config: MiMoConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = MiMoModel(
+            config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+    ) -> torch.Tensor:
+        hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
+        if not get_embedding:
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            return self.pooler(hidden_states, forward_batch)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            ("qkv_proj", "q_proj", "q"),
+            ("qkv_proj", "k_proj", "k"),
+            ("qkv_proj", "v_proj", "v"),
+            ("gate_up_proj", "gate_proj", 0),
+            ("gate_up_proj", "up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+        for name, loaded_weight in weights:
+            if (
+                "rotary_emb.inv_freq" in name
+                or "projector" in name
+                or "mtp_layers" in name
+            ):
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                weight_loader(param, loaded_weight)
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+
+EntryClass = MiMoForCausalLM
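
With `EntryClass = MiMoForCausalLM` registering the architecture, a MiMo checkpoint can be served like any other sglang model. A hypothetical offline-engine snippet; the Hugging Face repo id below is an assumption and is not stated anywhere in this diff:

import sglang as sgl

# Hypothetical usage; the model id is a placeholder for a MiMo checkpoint.
llm = sgl.Engine(model_path="XiaomiMiMo/MiMo-7B-RL")
out = llm.generate("The capital of France is", {"temperature": 0, "max_new_tokens": 16})
print(out["text"])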

sglang/srt/openai_api/adapter.py

@@ -14,6 +14,7 @@
 """Conversion between OpenAI APIs and native SRT APIs"""
 
 import asyncio
+import base64
 import json
 import logging
 import os
@@ -528,6 +529,7 @@ def v1_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": request.thinking_budget,
             "stop": request.stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -966,47 +968,23 @@ def v1_chat_generate_request(
 
         if chat_template_name is None:
             openai_compatible_messages = []
-            if (
-                tools
-                and tokenizer_manager.server_args.tool_call_parser == "deepseekv3"
-            ):
-                # add function call prompt to deepseekv3
-                openai_compatible_messages.append(
-                    {
-                        "role": "system",
-                        "content": """You are a helpful Assistant.
-## Tools
-### Function
-You have the following functions available:
-"""
-                        + "".join(
-                            [
-                                f"""
-- `{tool['name']}`:
-```json
-{json.dumps(tool)}
-```
-"""
-                                for tool in tools
-                            ]
-                        ),
-                    }
-                )
 
             for message in request.messages:
                 if message.content is None:
                     message.content = ""
-                if isinstance(message.content, str):
-                    openai_compatible_messages.append(
-                        {"role": message.role, "content": message.content}
-                    )
+                msg_dict = message.dict()
+                if isinstance(msg_dict.get("content"), list):
+                    for chunk in msg_dict["content"]:
+                        if isinstance(chunk, dict) and chunk.get("type") == "text":
+                            new_msg = msg_dict.copy()
+                            new_msg["content"] = chunk["text"]
+                            new_msg = {
+                                k: v for k, v in new_msg.items() if v is not None
+                            }
+                            openai_compatible_messages.append(new_msg)
                 else:
-                    content_list = message.dict()["content"]
-                    for content in content_list:
-                        if content["type"] == "text":
-                            openai_compatible_messages.append(
-                                {"role": message.role, "content": content["text"]}
-                            )
+                    msg_dict = {k: v for k, v in msg_dict.items() if v is not None}
+                    openai_compatible_messages.append(msg_dict)
             if (
                 openai_compatible_messages
                 and openai_compatible_messages[-1]["role"] == "assistant"
@@ -1124,6 +1102,7 @@ def v1_chat_generate_request(
             "temperature": request.temperature,
             "max_new_tokens": request.max_tokens or request.max_completion_tokens,
             "min_new_tokens": request.min_tokens,
+            "thinking_budget": request.thinking_budget,
             "stop": stop,
             "stop_token_ids": request.stop_token_ids,
             "top_p": request.top_p,
@@ -1316,7 +1295,8 @@ def v1_chat_generate_response(
             text, call_info_list = parser.parse_non_stream(text)
             tool_calls = [
                 ToolCall(
-                    id=str(call_info.tool_index),
+                    id=f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}",
+                    index=call_info.tool_index,
                     function=FunctionResponse(
                         name=call_info.name, arguments=call_info.parameters
                     ),
@@ -1432,6 +1412,7 @@ async def v1_chat_completions(
     reasoning_parser_dict = {}
 
     async def generate_stream_resp():
+        tool_call_first = True
         is_firsts = {}
         stream_buffers = {}
         n_prev_tokens = {}
@@ -1598,7 +1579,6 @@ async def v1_chat_completions(
                     # 2) if we found calls, we output them as separate chunk(s)
                    for call_item in calls:
                        # transform call_item -> FunctionResponse + ToolCall
-
                        if finish_reason_type == "stop":
                            latest_delta_len = 0
                            if isinstance(call_item.parameters, str):
@@ -1621,15 +1601,19 @@ async def v1_chat_completions(
                            call_item.parameters = remaining_call
 
                        finish_reason_type = "tool_calls"
-
                        tool_call = ToolCall(
-                            id=str(call_item.tool_index),
+                            id=(
+                                f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
+                                if tool_call_first
+                                else None
+                            ),
                            index=call_item.tool_index,
                            function=FunctionResponse(
                                name=call_item.name,
                                arguments=call_item.parameters,
                            ),
                        )
+                        tool_call_first = False
                        choice_data = ChatCompletionResponseStreamChoice(
                            index=index,
                            delta=DeltaMessage(tool_calls=[tool_call]),
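
Both the non-streaming and streaming paths above now emit OpenAI-style tool call ids instead of reusing the numeric tool index (the streaming path only sends the id on the first chunk of a call). The id is derived from a random UUID; a quick standalone check of what the same expression produces:

import base64
import uuid

# Same construction as in the diff: "call_" plus the URL-safe base64 of a random
# UUID with the trailing "=" padding stripped (22 characters of payload).
tool_call_id = f"call_{base64.urlsafe_b64encode(uuid.uuid4().bytes).rstrip(b'=').decode()}"
print(tool_call_id)  # e.g. call_kXc3T0m2TCeQ0q6vVb5xWg (value is random)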

sglang/srt/openai_api/protocol.py

@@ -172,6 +172,7 @@ class CompletionRequest(BaseModel):
     top_k: int = -1
     min_p: float = 0.0
     min_tokens: int = 0
+    thinking_budget: Optional[int] = None
     json_schema: Optional[str] = None
     regex: Optional[str] = None
     ebnf: Optional[str] = None
@@ -250,9 +251,29 @@ ChatCompletionMessageContentPart = Union[
 ]
 
 
+class FunctionResponse(BaseModel):
+    """Function response."""
+
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+class ToolCall(BaseModel):
+    """Tool call response."""
+
+    id: Optional[str] = None
+    index: Optional[int] = None
+    type: Literal["function"] = "function"
+    function: FunctionResponse
+
+
 class ChatCompletionMessageGenericParam(BaseModel):
     role: Literal["system", "assistant", "tool"]
     content: Union[str, List[ChatCompletionMessageContentTextPart], None]
+    tool_call_id: Optional[str] = None
+    name: Optional[str] = None
+    reasoning_content: Optional[str] = None
+    tool_calls: Optional[List[ToolCall]] = Field(default=None, examples=[None])
 
 
 class ChatCompletionMessageUserParam(BaseModel):
@@ -330,6 +351,13 @@ class ChatCompletionRequest(BaseModel):
         description="The maximum number of completion tokens for a chat completion request, "
         "including visible output tokens and reasoning tokens. Input tokens are not included. ",
     )
+    thinking_budget: Optional[int] = Field(
+        default=None,
+        description="The maximum number of reasoning tokens that can be generated for a request. "
+        "This setting of does not affect the thinking process of models. "
+        "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
+        "the reasoning content will be truncated and the final response content will be generated immediately.",
+    )
     n: int = 1
     presence_penalty: float = 0.0
     response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
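
The new `thinking_budget` field on both request models is forwarded into the sampling parameters by `v1_generate_request` and `v1_chat_generate_request` (see the adapter hunks above). A hypothetical client request against a locally running server; the URL and model name are placeholders, not taken from this diff:

import requests

# Hypothetical request; host, port, and model name are placeholders.
resp = requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "my-reasoning-model",
        "messages": [{"role": "user", "content": "How many primes are there below 100?"}],
        "max_tokens": 512,
        "thinking_budget": 128,  # cap on reasoning tokens before the final answer starts
    },
)
print(resp.json()["choices"][0]["message"])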

@@ -378,22 +406,6 @@ class ChatCompletionRequest(BaseModel):
     bootstrap_room: Optional[int] = None
 
 
-class FunctionResponse(BaseModel):
-    """Function response."""
-
-    name: Optional[str] = None
-    arguments: Optional[str] = None
-
-
-class ToolCall(BaseModel):
-    """Tool call response."""
-
-    id: str
-    index: Optional[int] = None
-    type: Literal["function"] = "function"
-    function: FunctionResponse
-
-
 class ChatMessage(BaseModel):
     role: Optional[str] = None
     content: Optional[str] = None

sglang/srt/reasoning_parser.py

@@ -32,7 +32,7 @@ class BaseReasoningFormatDetector:
         One-time parsing: Detects and parses reasoning sections in the provided text.
         Returns both reasoning content and normal text separately.
         """
-        text = text.replace(self.think_start_token, "").strip()
+        text = text.replace(self.think_start_token, "")
         if self.think_end_token not in text:
             # Assume reasoning was truncated before `</think>` token
             return StreamingParseResult(reasoning_text=text)
@@ -73,7 +73,7 @@ class BaseReasoningFormatDetector:
             normal_text = current_text[end_idx + len(self.think_end_token) :]
 
             return StreamingParseResult(
-                normal_text=normal_text, reasoning_text=reasoning_text.rstrip()
+                normal_text=normal_text, reasoning_text=reasoning_text
             )
 
         # Continue with reasoning content