sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +0 -12
- sglang/bench_one_batch.py +0 -12
- sglang/bench_serving.py +11 -2
- sglang/lang/backend/openai.py +10 -0
- sglang/srt/aio_rwlock.py +100 -0
- sglang/srt/configs/model_config.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +6 -0
- sglang/srt/layers/attention/flashinfer_backend.py +49 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
- sglang/srt/layers/linear.py +20 -2
- sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
- sglang/srt/layers/moe/fused_moe_native.py +46 -0
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
- sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
- sglang/srt/layers/moe/topk.py +205 -0
- sglang/srt/layers/quantization/__init__.py +3 -3
- sglang/srt/layers/quantization/fp8.py +169 -32
- sglang/srt/layers/quantization/fp8_kernel.py +292 -0
- sglang/srt/layers/quantization/fp8_utils.py +90 -1
- sglang/srt/layers/torchao_utils.py +11 -15
- sglang/srt/managers/schedule_batch.py +16 -10
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +13 -16
- sglang/srt/managers/tokenizer_manager.py +130 -111
- sglang/srt/mem_cache/memory_pool.py +15 -8
- sglang/srt/model_executor/cuda_graph_runner.py +1 -1
- sglang/srt/model_loader/loader.py +22 -11
- sglang/srt/models/dbrx.py +1 -1
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_v2.py +67 -18
- sglang/srt/models/gemma2.py +19 -0
- sglang/srt/models/grok.py +1 -1
- sglang/srt/models/llama.py +2 -2
- sglang/srt/models/mixtral.py +2 -2
- sglang/srt/models/olmoe.py +1 -1
- sglang/srt/models/qwen2_moe.py +1 -1
- sglang/srt/models/xverse_moe.py +1 -1
- sglang/srt/openai_api/adapter.py +23 -0
- sglang/srt/openai_api/protocol.py +2 -0
- sglang/srt/sampling/sampling_params.py +9 -2
- sglang/srt/server.py +21 -37
- sglang/srt/utils.py +33 -44
- sglang/test/test_block_fp8.py +341 -0
- sglang/version.py +1 -1
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
- sglang/srt/layers/fused_moe_patch.py +0 -133
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
- /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -19,6 +19,7 @@
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
+import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.quantization.fp8_utils import (
+    block_quant_to_tensor_quant,
+    input_to_float8,
+)
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
         return x
 
 
+class MoEGate(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.weight = nn.Parameter(
+            torch.empty((config.n_routed_experts, config.hidden_size))
+        )
+        if config.topk_method == "noaux_tc":
+            self.e_score_correction_bias = nn.Parameter(
+                torch.empty((config.n_routed_experts))
+            )
+        else:
+            self.e_score_correction_bias = None
+
+    def forward(self, hidden_states):
+        logits = F.linear(hidden_states, self.weight, None)
+        return logits
+
+
 class DeepseekV2MoE(nn.Module):
 
     def __init__(
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                 "Only silu is supported for now."
             )
 
+        self.gate = MoEGate(config=config)
+
         MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
         self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
             use_grouped_topk=True,
             num_expert_group=config.n_group,
             topk_group=config.topk_group,
+            correction_bias=self.gate.e_score_correction_bias,
         )
 
-        self.gate = ReplicatedLinear(
-            config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-        )
         if config.n_shared_experts is not None:
             intermediate_size = config.moe_intermediate_size * config.n_shared_experts
             self.shared_experts = DeepseekV2MLP(
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
         if self.n_shared_experts is not None:
             shared_output = self.shared_experts(hidden_states)
         # router_logits: (num_tokens, n_experts)
-        router_logits, _ = self.gate(hidden_states)
+        router_logits = self.gate(hidden_states)
         final_hidden_states = (
             self.experts(hidden_states=hidden_states, router_logits=router_logits)
             * self.routed_scaling_factor
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
     return 0.1 * mscale * math.log(scale) + 1.0
 
 
-def input_to_float8(x, dtype=torch.float8_e4m3fn):
-    finfo = torch.finfo(dtype)
-    min_val, max_val = x.aminmax()
-    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-    scale = finfo.max / amax
-    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
 class DeepseekV2Attention(nn.Module):
 
     def __init__(
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-
+
+        if rope_scaling:
+            rope_scaling["rope_type"] = "deepseek_yarn"
+
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
@@ -454,6 +471,8 @@ class DeepseekV2AttentionMLA(nn.Module):
             scaling_factor = rope_scaling["factor"]
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale
+        else:
+            self.rotary_emb.forward = self.rotary_emb.forward_native
 
         self.attn_mqa = RadixAttention(
             self.num_local_heads,
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):
 
         params_dict = dict(self.named_parameters())
         for name, loaded_weight in weights:
+            # TODO(HandH1998): Modify it when nextn is supported.
+            if hasattr(self.config, "num_nextn_predict_layers"):
+                num_nextn_layers = self.config.num_nextn_predict_layers
+                if num_nextn_layers > 0 and name.startswith("model.layers"):
+                    name_list = name.split(".")
+                    if (
+                        len(name_list) >= 3
+                        and int(name_list[2]) >= self.config.num_hidden_layers
+                    ):
+                        continue
             if "rotary_emb.inv_freq" in name:
                 continue
             for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -909,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                 ).T
             else:
                 w = self_attn.kv_b_proj.weight
+            # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+            # This may affect the accuracy of fp8 model.
+            if (
+                hasattr(self.quant_config, "weight_block_size")
+                and w.dtype == torch.float8_e4m3fn
+            ):
+                weight_block_size = self.quant_config.weight_block_size
+                if weight_block_size is not None:
+                    assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                    w, scale = block_quant_to_tensor_quant(
+                        w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                    )
+                    self_attn.w_scale = scale
             w_kc, w_vc = w.unflatten(
                 0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
             ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
             self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
             self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-            if hasattr(self_attn.kv_b_proj, "weight_scale"):
+            if (
+                hasattr(self_attn.kv_b_proj, "weight_scale")
+                and self_attn.w_scale is None
+            ):
                 self_attn.w_scale = self_attn.kv_b_proj.weight_scale
 
 
-EntryClass = DeepseekV2ForCausalLM
+class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+    pass
+
+
+EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
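Note on the routing change above: the `ReplicatedLinear` gate is replaced by the bare-linear `MoEGate`, whose logits go straight to the MoE implementation along with an optional `e_score_correction_bias` (only allocated for the "noaux_tc" top-k method). The sketch below is illustrative only, not sglang code; the real grouped top-k and correction-bias handling live in sglang/srt/layers/moe/topk.py, and the sizes here are made up.

# Illustrative sketch: a bias-free linear gate producing router logits,
# followed by a plain softmax top-k standing in for the grouped top-k.
import torch
import torch.nn.functional as F

num_tokens, hidden_size, num_experts, top_k = 4, 16, 8, 2
gate_weight = torch.randn(num_experts, hidden_size)   # one row per routed expert
hidden_states = torch.randn(num_tokens, hidden_size)

router_logits = F.linear(hidden_states, gate_weight, None)  # (num_tokens, num_experts)
topk_weights, topk_ids = router_logits.softmax(dim=-1).topk(top_k, dim=-1)
print(topk_ids.shape, topk_weights.shape)  # torch.Size([4, 2]) torch.Size([4, 2])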
sglang/srt/models/gemma2.py
CHANGED
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):
 
 
 class Gemma2ForCausalLM(nn.Module):
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".gate_proj.",
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        "q_proj": ("qkv_proj", 0),
+        "k_proj": ("qkv_proj", 1),
+        "v_proj": ("qkv_proj", 2),
+        "gate_proj": ("gate_up_proj", 0),
+        "up_proj": ("gate_up_proj", 1),
+    }
+
     packed_modules_mapping = {
         "qkv_proj": [
             "q_proj",
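The new class attributes tell the bitsandbytes loading path which per-projection checkpoint weights fold into which stacked parameter and at which shard index. A hypothetical illustration of how a loader could consult the mapping (the checkpoint name below is invented; this is not sglang's loader code):

# Hypothetical illustration of a bitsandbytes_stacked_params_mapping lookup.
bitsandbytes_stacked_params_mapping = {
    # shard_name: (stacked parameter name, shard index)
    "q_proj": ("qkv_proj", 0),
    "k_proj": ("qkv_proj", 1),
    "v_proj": ("qkv_proj", 2),
    "gate_proj": ("gate_up_proj", 0),
    "up_proj": ("gate_up_proj", 1),
}

checkpoint_name = "model.layers.0.self_attn.k_proj.weight"
for shard_name, (stacked_name, shard_id) in bitsandbytes_stacked_params_mapping.items():
    if shard_name in checkpoint_name:
        print(checkpoint_name.replace(shard_name, stacked_name), "shard", shard_id)
        break
# -> model.layers.0.self_attn.qkv_proj.weight shard 1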
sglang/srt/models/grok.py
CHANGED
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import GeluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/llama.py
CHANGED
@@ -325,8 +325,8 @@ class LlamaForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = LlamaModel(config, quant_config=quant_config)
-        # Llama 3.2 1B
-        # Llama 3.1 8B
+        # Llama 3.2 1B Instruct set tie_word_embeddings to True
+        # Llama 3.1 8B Instruct set tie_word_embeddings to False
         if self.config.tie_word_embeddings:
             self.lm_head = self.model.embed_tokens
         else:
sglang/srt/models/mixtral.py
CHANGED
@@ -27,8 +27,6 @@ from vllm.distributed import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.ep_moe.layer import EPMoE
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/olmoe.py
CHANGED
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -29,7 +29,6 @@ from vllm.distributed import (
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
     RowParallelLinear,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/xverse_moe.py
CHANGED
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
 
-from sglang.srt.layers.fused_moe_triton import fused_moe
 from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/openai_api/adapter.py
CHANGED
@@ -517,6 +517,7 @@ def v1_generate_request(
         "repetition_penalty": request.repetition_penalty,
         "regex": request.regex,
         "json_schema": request.json_schema,
+        "ebnf": request.ebnf,
         "n": request.n,
         "no_stop_trim": request.no_stop_trim,
         "ignore_eos": request.ignore_eos,
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):
 
 async def v1_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [CompletionRequest(**request_json)]
     adapted_request, request = v1_generate_request(all_requests)
 
@@ -858,6 +867,7 @@ def v1_chat_generate_request(
     logprob_start_lens = []
     top_logprobs_nums = []
     modalities_list = []
+    lora_paths = []
 
     # NOTE: with openai API, the prompt's logprobs are always not computed
 
@@ -920,6 +930,7 @@
             return_logprobs.append(request.logprobs)
             logprob_start_lens.append(-1)
             top_logprobs_nums.append(request.top_logprobs or 0)
+            lora_paths.append(request.lora_path)
 
         sampling_params = {
             "temperature": request.temperature,
@@ -934,6 +945,7 @@
             "frequency_penalty": request.frequency_penalty,
             "repetition_penalty": request.repetition_penalty,
             "regex": request.regex,
+            "ebnf": request.ebnf,
             "n": request.n,
             "no_stop_trim": request.no_stop_trim,
             "ignore_eos": request.ignore_eos,
@@ -958,6 +970,7 @@
         logprob_start_lens = logprob_start_lens[0]
         top_logprobs_nums = top_logprobs_nums[0]
         modalities_list = modalities_list[0]
+        lora_paths = lora_paths[0]
     else:
         if isinstance(input_ids[0], str):
             prompt_kwargs = {"text": input_ids}
@@ -975,6 +988,7 @@
         return_text_in_logprobs=True,
         rid=request_ids,
         modalities=modalities_list,
+        lora_path=lora_paths,
     )
 
     return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1104,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):
 
 async def v1_chat_completions(tokenizer_manager, raw_request: Request):
     request_json = await raw_request.json()
+    if "extra_body" in request_json:
+        extra = request_json["extra_body"]
+        # For example, if 'ebnf' is given:
+        if "ebnf" in extra:
+            request_json["ebnf"] = extra["ebnf"]
+        if "regex" in extra:
+            request_json["regex"] = extra["regex"]
+        # remove extra_body to avoid pydantic conflict
+        del request_json["extra_body"]
     all_requests = [ChatCompletionRequest(**request_json)]
     adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)
 
sglang/srt/openai_api/protocol.py
CHANGED
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None
 
 
 class CompletionResponseChoice(BaseModel):
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
     ignore_eos: bool = False
     skip_special_tokens: bool = True
     lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+    ebnf: Optional[str] = None
 
 
 class ChatMessage(BaseModel):
sglang/srt/sampling/sampling_params.py
CHANGED
@@ -36,6 +36,7 @@ class SamplingParams:
         regex: Optional[str] = None,
         n: int = 1,
         json_schema: Optional[str] = None,
+        ebnf: Optional[str] = None,
         no_stop_trim: bool = False,
         ignore_eos: bool = False,
         skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@ class SamplingParams:
         self.regex = regex
         self.n = n
         self.json_schema = json_schema
+        self.ebnf = ebnf
         self.no_stop_trim = no_stop_trim
 
         # Process some special cases
@@ -111,8 +113,13 @@ class SamplingParams:
                 f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                 f"{self.min_new_tokens}."
             )
-
-
+        grammars = [
+            self.json_schema,
+            self.regex,
+            self.ebnf,
+        ]  # since mutually exclusive, only one can be set
+        if sum(x is not None for x in grammars) > 1:
+            raise ValueError("Only one of regex, json_schema, or ebnf can be set.")
 
     def normalize(self, tokenizer):
         # Process stop strings
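Taken together, the adapter, protocol, and sampling-parameter changes expose EBNF-constrained decoding through the OpenAI-compatible API: the client puts the grammar in extra_body, the adapter hoists it onto the request before validation, and SamplingParams enforces that only one of regex, json_schema, and ebnf is set. A hedged client-side example follows; the server URL, model name, and grammar are placeholders, and it assumes the official openai Python client, which forwards extra_body verbatim.

# Example request; the concrete values are placeholders, not taken from this diff.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")
response = client.chat.completions.create(
    model="default",
    messages=[{"role": "user", "content": "Answer yes or no."}],
    temperature=0,
    extra_body={"ebnf": 'root ::= "yes" | "no"'},  # picked up by v1_chat_completions
)
print(response.choices[0].message.content)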
sglang/srt/server.py
CHANGED
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
     try:
         ret = await tokenizer_manager.get_weights_by_name(obj, request)
         if ret is None:
-            return ORJSONResponse(
-                {"error": {"message": "Get parameter by name failed"}},
-                status_code=HTTPStatus.BAD_REQUEST,
-            )
+            return _create_error_response("Get parameter by name failed")
         else:
             return ORJSONResponse(ret, status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/open_session", methods=["GET", "POST"])
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
         session_id = await tokenizer_manager.open_session(obj, request)
         return session_id
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
         await tokenizer_manager.close_session(obj, request)
         return Response(status_code=200)
     except Exception as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 # fastapi implicitly converts json in the request to obj (dataclass)
@@ -311,9 +302,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        logger.error(f"Error: {e}")
+        return _create_error_response(e)
 
 
 @app.api_route("/encode", methods=["POST", "PUT"])
@@ -324,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 @app.api_route("/classify", methods=["POST", "PUT"])
@@ -337,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
         ret = await tokenizer_manager.generate_request(obj, request).__anext__()
         return ret
     except ValueError as e:
-        return ORJSONResponse(
-            {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-        )
+        return _create_error_response(e)
 
 
 ##### OpenAI-compatible API endpoints #####
@@ -415,6 +401,12 @@ async def retrieve_file_content(file_id: str):
     return await v1_retrieve_file_content(file_id)
 
 
+def _create_error_response(e):
+    return ORJSONResponse(
+        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+    )
+
+
 def launch_engine(
     server_args: ServerArgs,
 ):
@@ -848,12 +840,10 @@ class Engine:
             group_name=group_name,
            backend=backend,
         )
-
-        async def _init_group():
-            return await tokenizer_manager.init_weights_update_group(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_init_group())
+        return loop.run_until_complete(
+            tokenizer_manager.init_weights_update_group(obj, None)
+        )
 
     def update_weights_from_distributed(self, name, dtype, shape):
         """Update weights from distributed source."""
@@ -862,22 +852,16 @@ class Engine:
             dtype=dtype,
             shape=shape,
         )
-
-        async def _update_weights():
-            return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_update_weights())
+        return loop.run_until_complete(
+            tokenizer_manager.update_weights_from_distributed(obj, None)
+        )
 
     def get_weights_by_name(self, name, truncate_size=100):
         """Get weights by parameter name."""
         obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-
-        async def _get_weights():
-            return await tokenizer_manager.get_weights_by_name(obj, None)
-
         loop = asyncio.get_event_loop()
-        return loop.run_until_complete(_get_weights())
+        return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))
 
 
 class Runtime:
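The Engine changes above drop the throwaway nested async wrappers and hand the coroutine directly to loop.run_until_complete. A minimal generic sketch of that pattern (plain asyncio, not sglang code; fetch_value stands in for the tokenizer manager call):

# Generic sketch: drive a coroutine to completion from synchronous code.
import asyncio

async def fetch_value(name: str) -> str:
    await asyncio.sleep(0)          # placeholder for real async work
    return f"value-of-{name}"

loop = asyncio.new_event_loop()     # the Engine methods use asyncio.get_event_loop()
try:
    print(loop.run_until_complete(fetch_value("lm_head.weight")))
finally:
    loop.close()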
sglang/srt/utils.py
CHANGED
@@ -14,6 +14,7 @@
 """Common utilities."""
 
 import base64
+import dataclasses
 import ipaddress
 import itertools
 import json
@@ -1238,49 +1239,37 @@
 return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))
 
 
-def
-
-
-
-
-    """
-    Determine whether to use tensor cores for attention computation.
-
-    Args:
-        kv_cache_dtype: Data type of the KV cache
-        num_attention_heads: Number of attention heads
-        num_kv_heads: Number of key/value heads
-
-    Returns:
-        bool: Whether to use tensor cores
-    """
-    # Try to use environment variable first
-    env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
-    if env_override is not None:
-        return env_override.lower() == "true"
-
-    # Try to use _grouped_size_compiled_for_decode_kernels if available
-    # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-    try:
-        from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-        if not _grouped_size_compiled_for_decode_kernels(
-            num_attention_heads,
-            num_kv_heads,
-        ):
-            return True
+def dataclass_to_string_truncated(data, max_length=2048):
+    if isinstance(data, str):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return f'"{data[:half_length]} ... {data[-half_length:]}"'
         else:
-            return
-
-
-
-
-
-
-
-
-
-
-
+            return f'"{data}"'
+    elif isinstance(data, (list, tuple)):
+        if len(data) > max_length:
+            half_length = max_length // 2
+            return str(data[:half_length]) + " ... " + str(data[-half_length:])
+        else:
+            return str(data)
+    elif isinstance(data, dict):
+        return (
+            "{"
+            + ", ".join(
+                f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                for k, v in data.items()
+            )
+            + "}"
+        )
+    elif dataclasses.is_dataclass(data):
+        fields = dataclasses.fields(data)
+        return (
+            f"{data.__class__.__name__}("
+            + ", ".join(
+                f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                for f in fields
+            )
+            + ")"
+        )
     else:
-        return
+        return str(data)
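A quick usage illustration of the new helper (the dataclass below is invented for the example; only dataclass_to_string_truncated itself comes from this package version):

# Long string fields are collapsed to "head ... tail" once they exceed max_length.
import dataclasses
from sglang.srt.utils import dataclass_to_string_truncated

@dataclasses.dataclass
class DummyRequest:
    rid: str
    text: str

req = DummyRequest(rid="abc123", text="y" * 100)
print(dataclass_to_string_truncated(req, max_length=8))
# -> DummyRequest(rid="abc123", text="yyyy ... yyyy")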
|