sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
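Note on the file list above: the monolithic sglang/srt/function_call_parser.py is removed in this release and replaced by the new sglang/srt/function_call package (base_format_detector, per-model detectors, ebnf_composer, and function_call_parser). Downstream code that imported the old module needs the new path; a minimal illustration, matching the import change shown in the adapter.py diff below:

# Old import path (module deleted in 0.4.6.post5):
#   from sglang.srt.function_call_parser import FunctionCallParser
# New import path:
from sglang.srt.function_call.function_call_parser import FunctionCallParser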
sglang/srt/models/siglip.py
ADDED
@@ -0,0 +1,294 @@
+# Adapted from
+# https://github.com/huggingface/transformers/blob/af9b2eaa54c150741f298d6db939af6328e1dc38/src/transformers/models/siglip/modeling_siglip.py
+
+from functools import partial
+from typing import Optional, Type, Union
+
+import torch
+import torch.nn as nn
+from transformers import SiglipVisionConfig
+
+from sglang.srt.layers.activation import QuickGELU
+from sglang.srt.layers.attention.vision import VisionAttention
+from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
+from sglang.srt.utils import add_prefix
+
+
+# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+class SiglipVisionEmbeddings(nn.Module):
+
+    def __init__(self, config: SiglipVisionConfig):
+        super().__init__()
+        self.config = config
+        self.embed_dim = config.hidden_size
+        self.image_size = config.image_size
+        self.patch_size = config.patch_size
+
+        self.patch_embedding = nn.Conv2d(
+            in_channels=config.num_channels,
+            out_channels=self.embed_dim,
+            kernel_size=self.patch_size,
+            stride=self.patch_size,
+            padding="valid",
+        )
+
+        self.num_patches = (self.image_size // self.patch_size) ** 2
+        self.num_positions = self.num_patches
+        self.position_embedding = VocabParallelEmbedding(
+            self.num_positions, self.embed_dim
+        )
+        self.register_buffer(
+            "position_ids",
+            torch.arange(self.num_positions).expand((1, -1)),
+            persistent=False,
+        )
+
+    def forward(self, pixel_values: torch.Tensor) -> torch.Tensor:
+        target_dtype = self.patch_embedding.weight.dtype
+        patch_embeds = self.patch_embedding(
+            pixel_values.to(dtype=target_dtype)
+        )  # shape = [*, width, grid, grid]
+        embeddings = patch_embeds.flatten(2).transpose(1, 2)
+        # interpolate_pos_encoding is never used in sglang
+        embeddings = embeddings + self.position_embedding(self.position_ids)
+
+        return embeddings
+
+
+# Copied from sglang.srt.models.clip.CLIPMLP
+class SiglipMLP(nn.Module):
+
+    def __init__(
+        self,
+        config,
+        act_layer: Type[nn.Module] = QuickGELU,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.fc1 = ColumnParallelLinear(
+            config.hidden_size,
+            config.intermediate_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc1", prefix),
+        )
+        self.act = act_layer()
+        self.fc2 = RowParallelLinear(
+            config.intermediate_size,
+            config.hidden_size,
+            quant_config=quant_config,
+            prefix=add_prefix("fc2", prefix),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        x_parallel, _ = self.fc1(x)
+        x_parallel = self.act(x_parallel)
+        x, _ = self.fc2(x_parallel)
+        return x
+
+
+# Copied from sglang.srt.models.clip.CLIPEncoderLayer
+class SiglipEncoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        act_layer: Type[nn.Module] = QuickGELU,
+        norm_layer: Type[nn.Module] = None,
+        attn_implementation: Optional[str] = "sdpa",
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        if norm_layer is None:
+            norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layer_norm1 = norm_layer(config.hidden_size)
+        self.layer_norm2 = norm_layer(config.hidden_size)
+        if attn_implementation == "sdpa":
+            qkv_backend = "sdpa"
+            softmax_in_single_precision = False
+        elif attn_implementation == "flash_attention_2":
+            qkv_backend = "triton_attn"
+            softmax_in_single_precision = False
+        elif attn_implementation == "eager":
+            qkv_backend = "sdpa"
+            softmax_in_single_precision = True
+        self.self_attn = VisionAttention(
+            embed_dim=config.hidden_size,
+            num_heads=config.num_attention_heads,
+            projection_size=config.hidden_size,
+            use_qkv_parallel=True,
+            qkv_backend=qkv_backend,
+            softmax_in_single_precision=softmax_in_single_precision,
+            flatten_batch=True,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+        )
+        self.mlp = SiglipMLP(
+            config,
+            act_layer=act_layer,
+            quant_config=quant_config,
+            prefix=add_prefix("mlp", prefix),
+        )
+
+    def forward(
+        self,
+        hidden_states: torch.Tensor,
+        attention_mask: torch.Tensor,
+        causal_attention_mask: torch.Tensor,
+    ) -> torch.Tensor:
+
+        residual = hidden_states
+        hidden_states = self.layer_norm1(hidden_states)
+        # Siglip text model uses both `causal_attention_mask` and `attention_mask`
+        if attention_mask is not None and causal_attention_mask is not None:
+            attn_mask = attention_mask + causal_attention_mask
+        elif causal_attention_mask is not None:
+            attn_mask = causal_attention_mask
+        else:
+            attn_mask = attention_mask
+        hidden_states = self.self_attn(
+            hidden_states,
+            attention_mask=attn_mask,
+            # causal_attention_mask=causal_attention_mask,
+        )
+
+        hidden_states = residual + hidden_states
+        residual = hidden_states
+        hidden_states = self.layer_norm2(hidden_states)
+        hidden_states = self.mlp(hidden_states)
+        hidden_states = residual + hidden_states
+        return hidden_states
+
+
+# Copied from sglang.srt.models.clip.CLIPEncoder
+class SiglipEncoder(nn.Module):
+    """
+    Transformer encoder consisting of `config.num_hidden_layers` self
+    attention layers. Each layer is a [`SiglipEncoderLayer`].
+
+    Args:
+        config: SiglipConfig
+    """
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        num_hidden_layers = config.num_hidden_layers
+        norm_layer = partial(nn.LayerNorm, eps=config.layer_norm_eps)
+        self.layers = nn.ModuleList(
+            [
+                SiglipEncoderLayer(
+                    config=config,
+                    norm_layer=norm_layer,
+                    attn_implementation="sdpa",
+                    quant_config=quant_config,
+                    prefix=add_prefix(f"layers.{layer_idx}", prefix),
+                )
+                for layer_idx in range(num_hidden_layers)
+            ]
+        )
+
+    def forward(
+        self,
+        inputs_embeds: torch.Tensor,
+        attention_mask: torch.Tensor = None,
+        causal_attention_mask: torch.Tensor = None,
+        return_all_hidden_states: bool = False,
+    ) -> Union[torch.Tensor, list[torch.Tensor]]:
+        hidden_states_pool = [inputs_embeds]
+        hidden_states = inputs_embeds
+
+        for encoder_layer in self.layers:
+            hidden_states = encoder_layer(
+                hidden_states, attention_mask, causal_attention_mask
+            )
+            if return_all_hidden_states:
+                hidden_states_pool.append(hidden_states)
+        if return_all_hidden_states:
+            return hidden_states_pool
+        return hidden_states
+
+
+# Adapted from transformers.models.siglip.modeling_siglip.SiglipVisionTransformer
+class SiglipVisionTransformer(nn.Module):
+
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+        embed_dim = config.hidden_size
+
+        self.embeddings = SiglipVisionEmbeddings(config)
+
+        self.encoder = SiglipEncoder(
+            config=config,
+            quant_config=quant_config,
+            prefix=add_prefix("encoder", prefix),
+        )
+
+        num_hidden_layers = config.num_hidden_layers
+        if len(self.encoder.layers) > config.num_hidden_layers:
+            raise ValueError(
+                f"The original encoder only has {num_hidden_layers} "
+                f"layers, but you requested {len(self.encoder.layers)} layers."
+            )
+
+        # VisionAttention in SiglipEncoderLayer is multihead attention
+        self.post_layernorm = nn.LayerNorm(embed_dim, eps=config.layer_norm_eps)
+
+    @property
+    def device(self) -> torch.device:
+        return self.encoder.layers[0].layer_norm1.weight.device
+
+    def forward(
+        self,
+        pixel_values: torch.Tensor,
+    ) -> torch.Tensor:
+        hidden_states = self.embeddings(pixel_values.to(self.device))
+
+        return_all_hidden_states = False
+
+        last_hidden_state = self.encoder(
+            inputs_embeds=hidden_states,
+            return_all_hidden_states=return_all_hidden_states,
+        )
+
+        last_hidden_state = self.post_layernorm(last_hidden_state)
+
+        return last_hidden_state
+
+
+# Copied from sglang.srt.models.clip.CLIPVisionModel
+class SiglipVisionModel(nn.Module):
+    def __init__(
+        self,
+        config: SiglipVisionConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        self.vision_model = SiglipVisionTransformer(
+            config, quant_config, prefix=add_prefix("vision_model", prefix)
+        )
+
+    @property
+    def device(self) -> torch.device:
+        return self.vision_model.device
+
+    def forward(self, pixel_values: torch.Tensor):
+        return self.vision_model(pixel_values)
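The new sglang/srt/models/siglip.py above mirrors Hugging Face's SiglipVisionTransformer: a Conv2d with kernel and stride equal to patch_size patchifies the image, every patch gets a learned position embedding, and there is no CLS token (interpolate_pos_encoding is unused). A quick, self-contained way to check the resulting sequence length for a given vision config; the numbers below are illustrative, not tied to any particular checkpoint:

from transformers import SiglipVisionConfig

cfg = SiglipVisionConfig(image_size=384, patch_size=14, hidden_size=1152)  # example values only
num_patches = (cfg.image_size // cfg.patch_size) ** 2  # same formula as SiglipVisionEmbeddings
print(num_patches)  # 729 patch positions, i.e. the encoder sequence length per image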
sglang/srt/openai_api/adapter.py
CHANGED
@@ -40,7 +40,7 @@ from sglang.srt.conversation import (
     get_conv_template_by_model_path,
     register_conv_template,
 )
-from sglang.srt.function_call_parser import FunctionCallParser
+from sglang.srt.function_call.function_call_parser import FunctionCallParser
 from sglang.srt.managers.io_struct import EmbeddingReqInput, GenerateReqInput
 from sglang.srt.openai_api.protocol import (
     BatchRequest,
@@ -970,7 +970,7 @@ def v1_chat_generate_request(
         # - image_data: None or a list of image strings (URLs or base64 strings).
         # - audio_data: None or a list of audio strings (URLs).
         # None skips any image processing in GenerateReqInput.
-
+        tool_call_constraint = None
         prompt = ""
         prompt_ids = []
         if not isinstance(request.messages, str):
@@ -989,7 +989,9 @@ def v1_chat_generate_request(

             tool_call_parser = tokenizer_manager.server_args.tool_call_parser
             parser = FunctionCallParser(request.tools, tool_call_parser)
-
+            tool_call_constraint = parser.get_structure_constraint(
+                request.tool_choice
+            )

         if chat_template_name is None:
             openai_compatible_messages = []
@@ -1156,20 +1158,24 @@ def v1_chat_generate_request(
                 request.response_format.model_dump(by_alias=True)
             )

-        if
-
-
-
-
-
-
-
-
+        # Check if there are already existing output constraints
+        has_existing_constraints = (
+            sampling_params.get("regex")
+            or sampling_params.get("ebnf")
+            or sampling_params.get("structural_tag")
+            or sampling_params.get("json_schema")
+        )
+
+        if tool_call_constraint and has_existing_constraints:
+            logger.warning("Constrained decoding is not compatible with tool calls.")
+        elif tool_call_constraint:
+            constraint_type, constraint_value = tool_call_constraint
+            if constraint_type == "structural_tag":
+                sampling_params[constraint_type] = convert_json_schema_to_str(
+                    constraint_value.model_dump(by_alias=True)
                 )
             else:
-                sampling_params[
-                strict_tag.model_dump(by_alias=True)
-                )
+                sampling_params[constraint_type] = constraint_value

         sampling_params_list.append(sampling_params)

@@ -1193,6 +1199,7 @@ def v1_chat_generate_request(
         top_logprobs_nums = top_logprobs_nums[0]
         modalities_list = modalities_list[0]
         lora_paths = lora_paths[0]
+        request_ids = request_ids[0]
     else:
         if tokenizer_manager.model_config.is_multimodal:
             # processor will need text input
@@ -1429,7 +1436,9 @@ async def v1_chat_completions(
         return create_error_response("Invalid request body, error: ", str(e))
     all_requests = [ChatCompletionRequest(**request_json)]
     created = int(time.time())
-    adapted_request, request = v1_chat_generate_request(
+    adapted_request, request = v1_chat_generate_request(
+        all_requests, tokenizer_manager, request_ids=[all_requests[0].rid]
+    )

     if adapted_request.stream:
         parser_dict = {}
@@ -1812,6 +1821,7 @@ def v1_embedding_request(all_requests, tokenizer_manager):
             prompt_kwargs = {"text": generate_prompts, "image_data": images}
         else:
             prompt_kwargs = {"input_ids": prompt}
+        request_ids = all_requests[0].rid
     else:
         if isinstance(prompts[0], str) or isinstance(prompts[0][0], str):
             prompt_kwargs = {"text": prompts}
@@ -1824,8 +1834,10 @@ def v1_embedding_request(all_requests, tokenizer_manager):
             )
         else:
             prompt_kwargs = {"input_ids": prompts}
+        request_ids = [req.rid for req in all_requests]

     adapted_request = EmbeddingReqInput(
+        rid=request_ids,
         **prompt_kwargs,
     )

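With the adapter changes above, a chat completion that supplies tools (and does not set tool_choice to "none") now has a constraint derived from FunctionCallParser.get_structure_constraint injected into its sampling parameters; if the request already sets regex, ebnf, structural_tag, or json_schema, a warning is logged and the tool constraint is skipped. A minimal request that exercises this path, assuming a local sglang server launched with a --tool-call-parser suitable for the model; the URL, port, model name, and tool definition are placeholders:

import requests

payload = {
    "model": "default",
    "messages": [{"role": "user", "content": "What is the weather in Paris?"}],
    "tools": [
        {
            "type": "function",
            "function": {
                "name": "get_weather",  # illustrative tool, not part of the diff
                "description": "Look up current weather for a city",
                "parameters": {
                    "type": "object",
                    "properties": {"city": {"type": "string"}},
                    "required": ["city"],
                },
            },
        }
    ],
    "tool_choice": "auto",
}
resp = requests.post("http://localhost:30000/v1/chat/completions", json=payload)
print(resp.json()["choices"][0]["message"].get("tool_calls"))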
sglang/srt/openai_api/protocol.py
CHANGED
@@ -392,6 +392,9 @@ class ChatCompletionRequest(BaseModel):
     stream_reasoning: bool = True
     chat_template_kwargs: Optional[Dict] = None

+    # The request id.
+    rid: Optional[str] = None
+
     # For PD disaggregation
     bootstrap_host: Optional[str] = None
     bootstrap_port: Optional[int] = None
@@ -466,6 +469,9 @@ class EmbeddingRequest(BaseModel):
     dimensions: int = None
     user: Optional[str] = None

+    # The request id.
+    rid: Optional[str] = None
+

 class EmbeddingObject(BaseModel):
     embedding: List[float]
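The new optional rid field on ChatCompletionRequest and EmbeddingRequest lets a client attach its own request id, which the adapter now forwards to the underlying GenerateReqInput/EmbeddingReqInput (see the request_ids plumbing in the adapter diff above). Illustrative usage with the OpenAI Python client via extra_body; the base_url and model name are placeholders for a local sglang deployment:

from openai import OpenAI

client = OpenAI(base_url="http://localhost:30000/v1", api_key="EMPTY")
resp = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "hello"}],
    extra_body={"rid": "my-trace-id-0001"},  # forwarded as the request id
)
print(resp.choices[0].message.content)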
sglang/srt/operations.py
ADDED
@@ -0,0 +1,154 @@
+import os
+from contextlib import contextmanager
+from dataclasses import dataclass
+from typing import Any, Callable, Dict, Generator, List, Sequence, Union
+
+import torch
+
+_ENABLE_PROFILE = bool(int(os.environ.get("SGLANG_OPERATIONS_ENABLE_PROFILE", "0")))
+
+if _ENABLE_PROFILE:
+    import nvtx
+
+
+def execute_operations(inputs, operations):
+    stages = _convert_operations_to_stages(decorate_operations(operations))
+    executor = _StageExecutor("primary", stages, inputs=inputs)
+    for _ in range(executor.num_stages):
+        executor.next()
+    assert executor.done
+    return executor.output
+
+
+class YieldOperation:
+    pass
+
+
+@dataclass
+class ExecutionOperation:
+    debug_name: str
+    fn: Callable
+
+
+Operation = Union[YieldOperation, ExecutionOperation, Callable]
+Stage = List[ExecutionOperation]
+
+
+class _StageExecutor:
+    def __init__(self, debug_name: str, stages: List[Stage], inputs):
+        self._debug_name = debug_name
+        self._stages = stages
+        self._index = 0
+        self._stage_state = _StateDict()
+        self._stage_output = inputs
+
+    def next(self):
+        assert not self.done
+
+        stage = self._stages[self._index]
+
+        with _annotate_region(debug_name=f"{self._debug_name}{self._index}"):
+            for op in stage:
+                with _annotate_region(debug_name=op.debug_name):
+                    self._stage_output = op.fn(
+                        state=self._stage_state,
+                        **(
+                            self._stage_output if self._stage_output is not None else {}
+                        ),
+                    )
+
+        self._index += 1
+
+    @property
+    def output(self):
+        assert self.done
+        return self._stage_output
+
+    @property
+    def done(self):
+        return self._index >= self.num_stages
+
+    @property
+    def num_stages(self):
+        return len(self._stages)
+
+
+@contextmanager
+def _annotate_region(debug_name):
+    if _ENABLE_PROFILE:
+        with torch.autograd.profiler.record_function(debug_name):
+            with nvtx.annotate(debug_name):
+                yield
+    else:
+        yield
+
+
+class _StateDict:
+    def __init__(self):
+        self._data = {}
+
+    def __setattr__(self, key, value):
+        if key == "_data":
+            super().__setattr__(key, value)
+            return
+        assert (
+            key not in self._data
+        ), f"`{key}` already exist, are you sure you want to override it?"
+        self._data[key] = value
+
+    def __getattr__(self, item):
+        return self._data[item]
+
+    def __delattr__(self, item):
+        del self._data[item]
+
+    def pop(self, item):
+        return self._data.pop(item)
+
+    def update(self, values: Dict[str, Any]):
+        for k, v in values.items():
+            setattr(self, k, v)
+
+    def clear(self, expect_keys: Sequence[str]):
+        if set(self._data.keys()) != set(expect_keys):
+            raise Exception(
+                f"Unexpected keys when clearning. This may indicate you do not release memory early enough but leave it to here. {list(self._data.keys())=} {expect_keys=}"
+            )
+
+        self._data.clear()
+
+
+def _convert_operations_to_stages(operations: List[Operation]) -> List[Stage]:
+    operation_chunks = list(
+        _chunk_by_separator(operations, lambda op: isinstance(op, YieldOperation))
+    )
+    assert all(len(chunk) > 0 for chunk in operation_chunks)
+    return operation_chunks
+
+
+def _chunk_by_separator(
+    items: List[Any], is_separator: Callable[[Any], bool]
+) -> Generator[List[Any], None, None]:
+    pending_items = []
+    for item in items:
+        if is_separator(item):
+            yield pending_items
+            pending_items = []
+        else:
+            pending_items.append(item)
+    if len(pending_items) > 0:
+        yield pending_items
+
+
+def decorate_operations(operations: List[Operation], debug_name_prefix: str = ""):
+    return [_decorate_operation(op, debug_name_prefix) for op in operations]
+
+
+def _decorate_operation(operation: Operation, debug_name_prefix: str):
+    if isinstance(operation, YieldOperation):
+        return operation
+    return ExecutionOperation(
+        debug_name=debug_name_prefix
+        + getattr(operation, "__name__", "unknown").replace("op_", ""),
+        fn=operation,
+    )
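operations.py introduces a small stage executor: a flat list of callables is split into stages at each YieldOperation, each callable receives the previous output dict as keyword arguments plus a shared state object, and NVTX/profiler annotation can be switched on with SGLANG_OPERATIONS_ENABLE_PROFILE=1. A minimal sketch of that contract using toy operations (op_double and op_add_original are made up for illustration, not part of the diff):

from sglang.srt.operations import YieldOperation, execute_operations

def op_double(state, x):
    state.original = x  # stash a value on the shared state for a later stage
    return {"x": x * 2}

def op_add_original(state, x):
    return {"x": x + state.pop("original")}

out = execute_operations(
    inputs={"x": 3},
    operations=[op_double, YieldOperation(), op_add_original],  # two stages
)
print(out)  # {'x': 9}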
sglang/srt/operations_strategy.py
ADDED
@@ -0,0 +1,31 @@
+import torch
+
+
+def compute_layer_operations(
+    layer: torch.nn.Module,
+):
+    if not layer.is_layer_sparse:
+        return [
+            layer.op_comm_prepare_attn,
+            layer.op_attn,
+            layer.op_comm_prepare_mlp,
+            layer.op_mlp,
+            layer.op_comm_postprocess_layer,
+        ]
+
+    # Will add TBO operation orders here
+    return [
+        layer.op_comm_prepare_attn,
+        layer.op_attn,
+        layer.op_comm_prepare_mlp,
+        layer.mlp.op_gate,
+        layer.mlp.op_shared_experts,
+        layer.mlp.op_select_experts,
+        layer.mlp.op_dispatch_a,
+        layer.mlp.op_dispatch_b,
+        layer.mlp.op_experts,
+        layer.mlp.op_combine_a,
+        layer.mlp.op_combine_b,
+        layer.mlp.op_output,
+        layer.op_comm_postprocess_layer,
+    ]
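operations_strategy.py pairs with the executor above: dense layers run the coarse attention/MLP pipeline, while sparse (MoE) layers expand the MLP into gate, shared-expert, expert-selection, dispatch, expert, combine, and output steps so they can later be interleaved across stages (the TBO operation orders are still a TODO in this release). A minimal sketch of what the function returns, using a stand-in object rather than a real decoder layer; real layers expose these op_* attributes as bound methods:

from types import SimpleNamespace
from sglang.srt.operations_strategy import compute_layer_operations

dense_layer = SimpleNamespace(
    is_layer_sparse=False,
    op_comm_prepare_attn="prepare_attn",
    op_attn="attn",
    op_comm_prepare_mlp="prepare_mlp",
    op_mlp="mlp",
    op_comm_postprocess_layer="postprocess_layer",
)
print(compute_layer_operations(dense_layer))
# ['prepare_attn', 'attn', 'prepare_mlp', 'mlp', 'postprocess_layer']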