sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180)
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/models/siglip.py (new file)
@@ -0,0 +1,294 @@
+ # Adapted from
+ # https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
+
+ from functools import partial
+ from typing import Optional, Type, Union
+
+ import torch
+ import torch.nn as nn
+ from transformers import SiglipVisionConfig
+
+ from sglang.srt.layers.activation import QuickGELU
+ from sglang.srt.layers.attention.vision import VisionAttention
+ from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+ from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+ from sglang.srt.utils import add_prefix
+
+
+ # Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+ class SiglipVisionEmbeddings(nn.Module):
+
+     def __init__(self, config: SiglipVisionConfig):
+         super().__init__()
+         self.config = config
+         self.embed_dim = config.hidden_size
+         self.image_size = config.image_size
+         self.patch_size = config.patch_size
+
+         self.patch_embedding = nn.Conv2d(
+             in_channels=config.num_channels,
+             out_channels=self.embed_dim,
+             kernel_size=self.patch_size,
+             stride=self.patch_size,
+             padding="valid",
+         )
+
+         self.num_patches = (self.image_size // self.patch_size) ** 2
+         self.num_positions = self.num_patches
+         self.position_embedding = VocabParallelEmbedding(
+             self.num_positions, self.embed_dim
+         )
+         self.register_buffer(
+             "position_ids",
+             torch.arange(self.num_positions).expand((1, -1)),
+             persistent=False,
+         )
+
+     def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+         target_dtype = self.patch_embedding.weight.dtype
+         patch_embeds = self.patch_embedding(
+             pixel_values.to(dtype=target_dtype)
+         )  # shape = [*, width, grid, grid]
+         embeddings = patch_embeds.flatten(2).transpose(1, 2)
+         # interpolate_pos_encoding is never used in sglang
+         embeddings = embeddings + self.position_embedding(self.position_ids)
+
+         return embeddings
+
+
+ # Copied from sglang.srt.models.clip.CLIPMLP
+ class SiglipMLP(nn.Module):
+
+     def __init__(
+         self,
+         config,
+         act_layer: Type[nn.Module] = QuickGELU,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.fc1 = ColumnParallelLinear(
+             config.hidden_size,
+             config.intermediate_size,
+             quant_config=quant_config,
+             prefix=add_prefix("fc1", prefix),
+         )
+         self.act = act_layer()
+         self.fc2 = RowParallelLinear(
+             config.intermediate_size,
+             config.hidden_size,
+             quant_config=quant_config,
+             prefix=add_prefix("fc2", prefix),
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         x_parallel, _ = self.fc1(x)
+         x_parallel = self.act(x_parallel)
+         x, _ = self.fc2(x_parallel)
+         return x
+
+
+ # Copied from sglang.srt.models.clip.CLIPEncoderLayer
+ class SiglipEncoderLayer(nn.Module):
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         act_layer: Type[nn.Module] = QuickGELU,
+         norm_layer: Type[nn.Module] = None,
+         attn_implementation: Optional[str] = "sdpa",
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+         if norm_layer is None:
+             norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+         self.layer_norm1 = norm_layer(config.hidden_size)
+         self.layer_norm2 = norm_layer(config.hidden_size)
+         if attn_implementation == "sdpa":
+             qkv_backend = "sdpa"
+             softmax_in_single_precision = False
+         elif attn_implementation == "flash_attention_2":
+             qkv_backend = "triton_attn"
+             softmax_in_single_precision = False
+         elif attn_implementation == "eager":
+             qkv_backend = "sdpa"
+             softmax_in_single_precision = True
+         self.self_attn = VisionAttention(
+             embed_dim=config.hidden_size,
+             num_heads=config.num_attention_heads,
+             projection_size=config.hidden_size,
+             use_qkv_parallel=True,
+             qkv_backend=qkv_backend,
+             softmax_in_single_precision=softmax_in_single_precision,
+             flatten_batch=True,
+             quant_config=quant_config,
+             prefix=add_prefix("self_attn", prefix),
+         )
+         self.mlp = SiglipMLP(
+             config,
+             act_layer=act_layer,
+             quant_config=quant_config,
+             prefix=add_prefix("mlp", prefix),
+         )
+
+     def forward(
+         self,
+         hidden_states: torch.Tensor,
+         attention_mask: torch.Tensor,
+         causal_attention_mask: torch.Tensor,
+     ) -> torch.Tensor:
+
+         residual = hidden_states
+         hidden_states = self.layer_norm1(hidden_states)
+         # Siglip text model uses both `causal_attention_mask` and `attention_mask`
+         if attention_mask is not None and causal_attention_mask is not None:
+             attn_mask = attention_mask + causal_attention_mask
+         elif causal_attention_mask is not None:
+             attn_mask = causal_attention_mask
+         else:
+             attn_mask = attention_mask
+         hidden_states = self.self_attn(
+             hidden_states,
+             attention_mask=attn_mask,
+             # causal_attention_mask=causal_attention_mask,
+         )
+
+         hidden_states = residual + hidden_states
+         residual = hidden_states
+         hidden_states = self.layer_norm2(hidden_states)
+         hidden_states = self.mlp(hidden_states)
+         hidden_states = residual + hidden_states
+         return hidden_states
+
+
+ # Copied from sglang.srt.models.clip.CLIPEncoder
+ class SiglipEncoder(nn.Module):
+     """
+     Transformer encoder consisting of `config.num_hidden_layers` self
+     attention layers. Each layer is a [`SiglipEncoderLayer`].
+
+     Args:
+         config: SiglipConfig
+     """
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+
+         num_hidden_layers = config.num_hidden_layers
+         norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+         self.layers = nn.ModuleList(
+             [
+                 SiglipEncoderLayer(
+                     config=config,
+                     norm_layer=norm_layer,
+                     attn_implementation="sdpa",
+                     quant_config=quant_config,
+                     prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                 )
+                 for layer_idx in range(num_hidden_layers)
+             ]
+         )
+
+     def forward(
+         self,
+         inputs_embeds: torch.Tensor,
+         attention_mask: torch.Tensor = None,
+         causal_attention_mask: torch.Tensor = None,
+         return_all_hidden_states: bool = False,
+     ) -> Union[torch.Tensor, list[torch.Tensor]]:
+         hidden_states_pool = [inputs_embeds]
+         hidden_states = inputs_embeds
+
+         for encoder_layer in self.layers:
+             hidden_states = encoder_layer(
+                 hidden_states, attention_mask, causal_attention_mask
+             )
+             if return_all_hidden_states:
+                 hidden_states_pool.append(hidden_states)
+         if return_all_hidden_states:
+             return hidden_states_pool
+         return hidden_states
+
+
+ # Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+ class SiglipVisionTransformer(nn.Module):
+
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ) -> None:
+         super().__init__()
+
+         self.config = config
+         embed_dim = config.hidden_size
+
+         self.embeddings = SiglipVisionEmbeddings(config)
+
+         self.encoder = SiglipEncoder(
+             config=config,
+             quant_config=quant_config,
+             prefix=add_prefix("encoder", prefix),
+         )
+
+         num_hidden_layers = config.num_hidden_layers
+         if len(self.encoder.layers) > config.num_hidden_layers:
+             raise ValueError(
+                 f"The original encoder only has {num_hidden_layers} "
+                 f"layers, but you requested {len(self.encoder.layers)} layers."
+             )
+
+         # VisionAttention in SiglipEncoderLayer is multihead attention
+         self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+     @property
+     def device(self) -> torch.device:
+         return self.encoder.layers[0].layer_norm1.weight.device
+
+     def forward(
+         self,
+         pixel_values: torch.Tensor,
+     ) -> torch.Tensor:
+         hidden_states = self.embeddings(pixel_values.to(self.device))
+
+         return_all_hidden_states = False
+
+         last_hidden_state = self.encoder(
+             inputs_embeds=hidden_states,
+             return_all_hidden_states=return_all_hidden_states,
+         )
+
+         last_hidden_state = self.post_layernorm(last_hidden_state)
+
+         return last_hidden_state
+
+
+ # Copied from sglang.srt.models.clip.CLIPVisionModel
+ class SiglipVisionModel(nn.Module):
+     def __init__(
+         self,
+         config: SiglipVisionConfig,
+         quant_config: Optional[QuantizationConfig] = None,
+         prefix: str = "",
+     ):
+         super().__init__()
+         self.vision_model = SiglipVisionTransformer(
+             config, quant_config, prefix=add_prefix("vision_model", prefix)
+         )
+
+     @property
+     def device(self) -> torch.device:
+         return self.vision_model.device
+
+     def forward(self, pixel_values: torch.Tensor):
+         return self.vision_model(pixel_values)
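For orientation, a minimal smoke-test sketch of the new SiglipVisionModel (not part of the diff). It assumes sglang's tensor-parallel process groups are already initialized, since ColumnParallelLinear and RowParallelLinear require them, and uses the default SiglipVisionConfig values (224x224 images, 16x16 patches, hence 196 patches per image):

import torch
from transformers import SiglipVisionConfig
from sglang.srt.models.siglip import SiglipVisionModel

config = SiglipVisionConfig()  # defaults: hidden_size=768, image_size=224, patch_size=16
model = SiglipVisionModel(config).eval()

# one RGB image, shape [batch, channels, height, width]
pixel_values = torch.randn(1, config.num_channels, config.image_size, config.image_size)
with torch.no_grad():
    features = model(pixel_values)  # [1, 196, config.hidden_size]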

sglang/bench_one_batch.py
@@ -37,7 +37,7 @@ $ python3 -m sglang.bench_one_batch --correct \
  --tensor-parallel-size 2 \
  --disable-cuda-graph
  ```
- We will eanble CUDA Graph support soon.
+ We will enable CUDA Graph support soon.
  """

  import types

sglang/srt/openai_api/adapter.py
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
      get_conv_template_by_model_path,
      register_conv_template,
  )
- from sglang.srt.function_call_parser import FunctionCallParser
+ from sglang.srt.function_call.function_call_parser import FunctionCallParser
  from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
  from sglang.srt.openai_api.protocol import (
      BatchRequest,
@@ -175,6 +175,32 @@ def guess_chat_template_name_from_model_path(model_path):
      )


+ def _validate_prompt(prompt: str):
+     """Validate that the prompt is not empty or whitespace only."""
+     is_invalid = False
+
+     # Check for empty/whitespace string
+     if isinstance(prompt, str):
+         is_invalid = not prompt.strip()
+     # Check for various invalid list cases: [], [""], [" "], [[]]
+     elif isinstance(prompt, list):
+         is_invalid = not prompt or (
+             len(prompt) == 1
+             and (
+                 (isinstance(prompt[0], str) and not prompt[0].strip())
+                 or (isinstance(prompt[0], list) and not prompt[0])
+             )
+         )
+
+     if is_invalid:
+         raise HTTPException(
+             status_code=400,
+             detail="Input cannot be empty or contain only whitespace.",
+         )
+
+     return prompt
+
+
  async def v1_files_create(
      file: UploadFile, purpose: str, file_storage_path: str = None
  ):
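A quick illustration of what the new validator accepts and rejects (hypothetical calls, not part of the diff; HTTPException is FastAPI's, as used by adapter.py):

from fastapi import HTTPException

_validate_prompt("Hello")    # returned unchanged
_validate_prompt([1, 2, 3])  # token-id lists pass through untouched
# Each of these raises HTTPException(status_code=400):
for bad in ["", "   ", [], [""], ["  "], [[]]]:
    try:
        _validate_prompt(bad)
    except HTTPException:
        pass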
@@ -529,7 +555,6 @@ def v1_generate_request(
              "temperature": request.temperature,
              "max_new_tokens": request.max_tokens,
              "min_new_tokens": request.min_tokens,
-             "thinking_budget": request.thinking_budget,
              "stop": request.stop,
              "stop_token_ids": request.stop_token_ids,
              "top_p": request.top_p,
@@ -591,7 +616,7 @@ def v1_generate_response(
      echo = False

      if (not isinstance(request, list)) and request.echo:
-         # TODO: handle the case propmt is token ids
+         # TODO: handle the case prompt is token ids
          if isinstance(request.prompt, list) and isinstance(request.prompt[0], str):
              # for the case of multiple str prompts
              prompts = request.prompt
@@ -647,7 +672,7 @@
          finish_reason = ret_item["meta_info"]["finish_reason"]

          if to_file:
-             # to make the choise data json serializable
+             # to make the choice data json serializable
              choice_data = {
                  "index": 0,
                  "text": text,
@@ -945,7 +970,7 @@ def v1_chat_generate_request(
          # - image_data: None or a list of image strings (URLs or base64 strings).
          # - audio_data: None or a list of audio strings (URLs).
          # None skips any image processing in GenerateReqInput.
-         strict_tag = None
+         tool_call_constraint = None
          prompt = ""
          prompt_ids = []
          if not isinstance(request.messages, str):
@@ -964,7 +989,9 @@

                  tool_call_parser = tokenizer_manager.server_args.tool_call_parser
                  parser = FunctionCallParser(request.tools, tool_call_parser)
-                 strict_tag = parser.get_structure_tag()
+                 tool_call_constraint = parser.get_structure_constraint(
+                     request.tool_choice
+                 )

              if chat_template_name is None:
                  openai_compatible_messages = []
@@ -1102,7 +1129,6 @@
              "temperature": request.temperature,
              "max_new_tokens": request.max_tokens or request.max_completion_tokens,
              "min_new_tokens": request.min_tokens,
-             "thinking_budget": request.thinking_budget,
              "stop": stop,
              "stop_token_ids": request.stop_token_ids,
              "top_p": request.top_p,
@@ -1132,20 +1158,24 @@
                  request.response_format.model_dump(by_alias=True)
              )

-         if strict_tag is not None:
-             if (
-                 sampling_params.get("regex")
-                 or sampling_params.get("ebnf")
-                 or sampling_params.get("structural_tag")
-                 or sampling_params.get("json_schema")
-             ):
-                 logger.warning(
-                     "Constrained decoding is not compatible with tool calls."
+         # Check if there are already existing output constraints
+         has_existing_constraints = (
+             sampling_params.get("regex")
+             or sampling_params.get("ebnf")
+             or sampling_params.get("structural_tag")
+             or sampling_params.get("json_schema")
+         )
+
+         if tool_call_constraint and has_existing_constraints:
+             logger.warning("Constrained decoding is not compatible with tool calls.")
+         elif tool_call_constraint:
+             constraint_type, constraint_value = tool_call_constraint
+             if constraint_type == "structural_tag":
+                 sampling_params[constraint_type] = convert_json_schema_to_str(
+                     constraint_value.model_dump(by_alias=True)
                  )
              else:
-                 sampling_params["structural_tag"] = convert_json_schema_to_str(
-                     strict_tag.model_dump(by_alias=True)
-                 )
+                 sampling_params[constraint_type] = constraint_value

          sampling_params_list.append(sampling_params)

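Read against this hunk, get_structure_constraint appears to return either None or a (constraint_type, constraint_value) tuple whose first element names a sampling-params key, with only "structural_tag" values needing JSON serialization first. A hedged sketch of that inferred contract (the tuple below is illustrative, not taken from the source):

# e.g. ("structural_tag", <pydantic model>)  -> serialized via model_dump first
# e.g. ("json_schema", "<schema string>")    -> stored on sampling_params as-is
sampling_params = {}
tool_call_constraint = ("json_schema", '{"type": "object"}')
if tool_call_constraint:
    constraint_type, constraint_value = tool_call_constraint
    sampling_params[constraint_type] = constraint_value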
@@ -1169,6 +1199,7 @@
          top_logprobs_nums = top_logprobs_nums[0]
          modalities_list = modalities_list[0]
          lora_paths = lora_paths[0]
+         request_ids = request_ids[0]
      else:
          if tokenizer_manager.model_config.is_multimodal:
              # processor will need text input
@@ -1405,7 +1436,9 @@ async def v1_chat_completions(
          return create_error_response("Invalid request body, error: ", str(e))
      all_requests = [ChatCompletionRequest(**request_json)]
      created = int(time.time())
-     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
+     adapted_request, request = v1_chat_generate_request(
+         all_requests, tokenizer_manager, request_ids=[all_requests[0].rid]
+     )

      if adapted_request.stream:
          parser_dict = {}
@@ -1755,6 +1788,8 @@ def v1_embedding_request(all_requests, tokenizer_manager):

      for request in all_requests:
          prompt = request.input
+         # Check for empty/whitespace string
+         prompt = _validate_prompt(request.input)
          assert (
              type(prompt) is first_prompt_type
          ), "All prompts must be of the same type in file input settings"
@@ -1786,6 +1821,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
              prompt_kwargs = {"text": generate_prompts, "image_data": images}
          else:
              prompt_kwargs = {"input_ids": prompt}
+         request_ids = all_requests[0].rid
      else:
          if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
              prompt_kwargs = {"text": prompts}
@@ -1798,8 +1834,10 @@
              )
          else:
              prompt_kwargs = {"input_ids": prompts}
+         request_ids = [req.rid for req in all_requests]

      adapted_request = EmbeddingReqInput(
+         rid=request_ids,
          **prompt_kwargs,
      )


sglang/srt/openai_api/protocol.py
@@ -172,7 +172,6 @@ class CompletionRequest(BaseModel):
      top_k: int = -1
      min_p: float = 0.0
      min_tokens: int = 0
-     thinking_budget: Optional[int] = None
      json_schema: Optional[str] = None
      regex: Optional[str] = None
      ebnf: Optional[str] = None
@@ -351,13 +350,6 @@ class ChatCompletionRequest(BaseModel):
          description="The maximum number of completion tokens for a chat completion request, "
          "including visible output tokens and reasoning tokens. Input tokens are not included. ",
      )
-     thinking_budget: Optional[int] = Field(
-         default=None,
-         description="The maximum number of reasoning tokens that can be generated for a request. "
-         "This setting of does not affect the thinking process of models. "
-         "If the number of tokens generated by the model's thinking process exceeds thinking_budget, "
-         "the reasoning content will be truncated and the final response content will be generated immediately.",
-     )
      n: int = 1
      presence_penalty: float = 0.0
      response_format: Optional[Union[ResponseFormat, StructuralTagResponseFormat]] = None
@@ -400,6 +392,9 @@
      stream_reasoning: bool = True
      chat_template_kwargs: Optional[Dict] = None

+     # The request id.
+     rid: Optional[str] = None
+
      # For PD disaggregation
      bootstrap_host: Optional[str] = None
      bootstrap_port: Optional[int] = None
@@ -474,6 +469,9 @@ class EmbeddingRequest(BaseModel):
      dimensions: int = None
      user: Optional[str] = None

+     # The request id.
+     rid: Optional[str] = None
+

  class EmbeddingObject(BaseModel):
      embedding: List[float]
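With the new rid fields, a client can attach its own request id to chat-completion and embedding calls, and v1_chat_completions above forwards it as request_ids=[all_requests[0].rid]. A hypothetical request (local endpoint and port assumed):

import requests

requests.post(
    "http://localhost:30000/v1/chat/completions",
    json={
        "model": "default",
        "messages": [{"role": "user", "content": "Hi"}],
        "rid": "my-trace-id-001",  # becomes the internal request id
    },
)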

sglang/srt/operations.py (new file)
@@ -0,0 +1,154 @@
+ import os
+ from contextlib import contextmanager
+ from dataclasses import dataclass
+ from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+
+ import torch
+
+ _ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
+
+ if _ENABLE_PROFILE:
+     import nvtx
+
+
+ def execute_operations(inputs, operations):
+     stages = _convert_operations_to_stages(decorate_operations(operations))
+     executor = _StageExecutor("primary", stages, inputs=inputs)
+     for _ in range(executor.num_stages):
+         executor.next()
+     assert executor.done
+     return executor.output
+
+
+ class YieldOperation:
+     pass
+
+
+ @dataclass
+ class ExecutionOperation:
+     debug_name: str
+     fn: Callable
+
+
+ Operation = Union[YieldOperation, ExecutionOperation, Callable]
+ Stage = List[ExecutionOperation]
+
+
+ class _StageExecutor:
+     def __init__(self, debug_name: str, stages: List[Stage], inputs):
+         self._debug_name = debug_name
+         self._stages = stages
+         self._index = 0
+         self._stage_state = _StateDict()
+         self._stage_output = inputs
+
+     def next(self):
+         assert not self.done
+
+         stage = self._stages[self._index]
+
+         with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
+             for op in stage:
+                 with _annotate_region(debug_name=op.debug_name):
+                     self._stage_output = op.fn(
+                         state=self._stage_state,
+                         **(
+                             self._stage_output if self._stage_output is not None else {}
+                         ),
+                     )
+
+         self._index += 1
+
+     @property
+     def output(self):
+         assert self.done
+         return self._stage_output
+
+     @property
+     def done(self):
+         return self._index >= self.num_stages
+
+     @property
+     def num_stages(self):
+         return len(self._stages)
+
+
+ @contextmanager
+ def _annotate_region(debug_name):
+     if _ENABLE_PROFILE:
+         with torch.autograd.profiler.record_function(debug_name):
+             with nvtx.annotate(debug_name):
+                 yield
+     else:
+         yield
+
+
+ class _StateDict:
+     def __init__(self):
+         self._data = {}
+
+     def __setattr__(self, key, value):
+         if key == "_data":
+             super().__setattr__(key, value)
+             return
+         assert (
+             key not in self._data
+         ), f"`{key}` already exist, are you sure you want to override it?"
+         self._data[key] = value
+
+     def __getattr__(self, item):
+         return self._data[item]
+
+     def __delattr__(self, item):
+         del self._data[item]
+
+     def pop(self, item):
+         return self._data.pop(item)
+
+     def update(self, values: Dict[str, Any]):
+         for k, v in values.items():
+             setattr(self, k, v)
+
+     def clear(self, expect_keys: Sequence[str]):
+         if set(self._data.keys()) != set(expect_keys):
+             raise Exception(
+                 f"Unexpected keys when clearning. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}"
+             )
+
+         self._data.clear()
+
+
+ def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+     operation_chunks = list(
+         _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
+     )
+     assert all(len(chunk) > 0 for chunk in operation_chunks)
+     return operation_chunks
+
+
+ def _chunk_by_separator(
+     items: List[Any], is_separator: Callable[[Any], bool]
+ ) -> Generator[List[Any], None, None]:
+     pending_items = []
+     for item in items:
+         if is_separator(item):
+             yield pending_items
+             pending_items = []
+         else:
+             pending_items.append(item)
+     if len(pending_items) > 0:
+         yield pending_items
+
+
+ def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+     return [_decorate_operation(op, debug_name_prefix) for op in operations]
+
+
+ def _decorate_operation(operation: Operation, debug_name_prefix: str):
+     if isinstance(operation, YieldOperation):
+         return operation
+     return ExecutionOperation(
+         debug_name=debug_name_prefix
+         + getattr(operation, "__name__", "unknown").replace("op_", ""),
+         fn=operation,
+     )
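A small usage sketch of the new operations module (hypothetical op functions, not from the diff): each operation is a callable taking the shared state plus the previous stage's outputs as keyword arguments, and YieldOperation marks a stage boundary:

from sglang.srt.operations import YieldOperation, execute_operations

def op_load(state, x):
    state.loaded = x + 1           # stash an intermediate on the stage state
    return {"x": state.loaded}

def op_double(state, x):
    state.pop("loaded")            # release state that is no longer needed
    return {"result": x * 2}

out = execute_operations(
    inputs={"x": 1},
    operations=[op_load, YieldOperation(), op_double],
)
print(out)  # {'result': 4}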