sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py CHANGED
@@ -19,6 +19,7 @@
  from typing import Any, Dict, Iterable, Optional, Tuple

  import torch
+ import torch.nn.functional as F
  from torch import nn
  from transformers import PretrainedConfig
  from vllm import _custom_ops as ops
@@ -31,8 +32,6 @@ from vllm.distributed import (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.ep_moe.layer import EPMoE
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      ColumnParallelLinear,
@@ -41,7 +40,13 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
+ from sglang.srt.layers.quantization.fp8_utils import (
+     block_quant_to_tensor_quant,
+     input_to_float8,
+ )
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
      ParallelLMHead,
@@ -90,6 +95,24 @@ class DeepseekV2MLP(nn.Module):
          return x


+ class MoEGate(nn.Module):
+     def __init__(self, config):
+         super().__init__()
+         self.weight = nn.Parameter(
+             torch.empty((config.n_routed_experts, config.hidden_size))
+         )
+         if config.topk_method == "noaux_tc":
+             self.e_score_correction_bias = nn.Parameter(
+                 torch.empty((config.n_routed_experts))
+             )
+         else:
+             self.e_score_correction_bias = None
+
+     def forward(self, hidden_states):
+         logits = F.linear(hidden_states, self.weight, None)
+         return logits
+
+
  class DeepseekV2MoE(nn.Module):

      def __init__(
@@ -114,6 +137,8 @@ class DeepseekV2MoE(nn.Module):
                  "Only silu is supported for now."
              )

+         self.gate = MoEGate(config=config)
+
          MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
          self.experts = MoEImpl(
              num_experts=config.n_routed_experts,
@@ -125,11 +150,9 @@ class DeepseekV2MoE(nn.Module):
              use_grouped_topk=True,
              num_expert_group=config.n_group,
              topk_group=config.topk_group,
+             correction_bias=self.gate.e_score_correction_bias,
          )

-         self.gate = ReplicatedLinear(
-             config.hidden_size, config.n_routed_experts, bias=False, quant_config=None
-         )
          if config.n_shared_experts is not None:
              intermediate_size = config.moe_intermediate_size * config.n_shared_experts
              self.shared_experts = DeepseekV2MLP(
@@ -146,7 +169,7 @@ class DeepseekV2MoE(nn.Module):
          if self.n_shared_experts is not None:
              shared_output = self.shared_experts(hidden_states)
          # router_logits: (num_tokens, n_experts)
-         router_logits, _ = self.gate(hidden_states)
+         router_logits = self.gate(hidden_states)
          final_hidden_states = (
              self.experts(hidden_states=hidden_states, router_logits=router_logits)
              * self.routed_scaling_factor
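A note on the router change above: the quantizable ReplicatedLinear gate is replaced by a plain MoEGate module that keeps an unquantized weight and returns raw logits, and when config.topk_method is "noaux_tc" (the DeepSeek-V3 routing scheme) it also carries an e_score_correction_bias that is handed to the MoE layer as correction_bias. Below is a minimal, self-contained sketch of that gating path; ToyConfig and every tensor size are invented for illustration and are not part of the diff.

# Minimal sketch of the new gate (mirrors the MoEGate added above).
# ToyConfig and the sizes are invented for illustration only.
from dataclasses import dataclass

import torch
import torch.nn.functional as F
from torch import nn


@dataclass
class ToyConfig:
    n_routed_experts: int = 8
    hidden_size: int = 16
    topk_method: str = "noaux_tc"  # DeepSeek-V3 style; anything else -> no bias


class MoEGate(nn.Module):
    def __init__(self, config):
        super().__init__()
        # Plain, unquantized weight matrix; no bias.
        self.weight = nn.Parameter(
            torch.empty((config.n_routed_experts, config.hidden_size))
        )
        if config.topk_method == "noaux_tc":
            # Expert-score correction bias, passed to the MoE layer as correction_bias.
            self.e_score_correction_bias = nn.Parameter(
                torch.empty((config.n_routed_experts))
            )
        else:
            self.e_score_correction_bias = None

    def forward(self, hidden_states):
        # Raw router logits; softmax/top-k happens inside the MoE layer.
        return F.linear(hidden_states, self.weight, None)


config = ToyConfig()
gate = MoEGate(config)
nn.init.normal_(gate.weight, std=0.02)
nn.init.zeros_(gate.e_score_correction_bias)
hidden = torch.randn(4, config.hidden_size)  # (num_tokens, hidden_size)
router_logits = gate(hidden)                 # (num_tokens, n_routed_experts)
print(router_logits.shape)                   # torch.Size([4, 8])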
@@ -167,15 +190,6 @@ def yarn_get_mscale(scale: float = 1, mscale: float = 1) -> float:
      return 0.1 * mscale * math.log(scale) + 1.0


- def input_to_float8(x, dtype=torch.float8_e4m3fn):
-     finfo = torch.finfo(dtype)
-     min_val, max_val = x.aminmax()
-     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
-     scale = finfo.max / amax
-     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
-     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
-
-
  class DeepseekV2Attention(nn.Module):

      def __init__(
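input_to_float8 is no longer defined in this file; it now lives in sglang.srt.layers.quantization.fp8_utils and is imported from there (see the import hunk above). The standalone sketch below reproduces the removed helper and shows a round trip; it assumes a torch build that provides torch.float8_e4m3fn (PyTorch 2.1+).

# Standalone sketch of the per-tensor FP8 helper shown in the removed lines;
# in 0.4.1 it is imported from sglang.srt.layers.quantization.fp8_utils.
import torch


def input_to_float8(x, dtype=torch.float8_e4m3fn):
    # Scale so the largest magnitude maps to the dtype's representable max.
    finfo = torch.finfo(dtype)
    min_val, max_val = x.aminmax()
    amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
    scale = finfo.max / amax
    x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
    # Return the quantized tensor and the dequantization scale (1 / scale).
    return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()


x = torch.randn(4, 8, dtype=torch.bfloat16)
x_fp8, dq_scale = input_to_float8(x)
x_approx = x_fp8.to(torch.bfloat16) * dq_scale  # lossy reconstruction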
@@ -439,7 +453,10 @@ class DeepseekV2AttentionMLA(nn.Module):
              quant_config=quant_config,
          )
          self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-         rope_scaling["rope_type"] = "deepseek_yarn"
+
+         if rope_scaling:
+             rope_scaling["rope_type"] = "deepseek_yarn"
+
          self.rotary_emb = get_rope(
              qk_rope_head_dim,
              rotary_dim=qk_rope_head_dim,
@@ -454,6 +471,8 @@ class DeepseekV2AttentionMLA(nn.Module):
              scaling_factor = rope_scaling["factor"]
              mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
              self.scaling = self.scaling * mscale * mscale
+         else:
+             self.rotary_emb.forward = self.rotary_emb.forward_native

          self.attn_mqa = RadixAttention(
              self.num_local_heads,
@@ -845,6 +864,16 @@ class DeepseekV2ForCausalLM(nn.Module):

          params_dict = dict(self.named_parameters())
          for name, loaded_weight in weights:
+             # TODO(HandH1998): Modify it when nextn is supported.
+             if hasattr(self.config, "num_nextn_predict_layers"):
+                 num_nextn_layers = self.config.num_nextn_predict_layers
+                 if num_nextn_layers > 0 and name.startswith("model.layers"):
+                     name_list = name.split(".")
+                     if (
+                         len(name_list) >= 3
+                         and int(name_list[2]) >= self.config.num_hidden_layers
+                     ):
+                         continue
              if "rotary_emb.inv_freq" in name:
                  continue
              for param_name, weight_name, shard_id in stacked_params_mapping:
@@ -909,13 +938,33 @@ class DeepseekV2ForCausalLM(nn.Module):
                      ).T
                  else:
                      w = self_attn.kv_b_proj.weight
+                 # NOTE(HandH1998): Since `bmm_fp8` only supports per-tensor scale, we have to requantize `self_attn.kv_b_proj`.
+                 # This may affect the accuracy of fp8 model.
+                 if (
+                     hasattr(self.quant_config, "weight_block_size")
+                     and w.dtype == torch.float8_e4m3fn
+                 ):
+                     weight_block_size = self.quant_config.weight_block_size
+                     if weight_block_size is not None:
+                         assert hasattr(self_attn.kv_b_proj, "weight_scale_inv")
+                         w, scale = block_quant_to_tensor_quant(
+                             w, self_attn.kv_b_proj.weight_scale_inv, weight_block_size
+                         )
+                         self_attn.w_scale = scale
                  w_kc, w_vc = w.unflatten(
                      0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                  ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                  self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                  self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
-                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
+                 if (
+                     hasattr(self_attn.kv_b_proj, "weight_scale")
+                     and self_attn.w_scale is None
+                 ):
                      self_attn.w_scale = self_attn.kv_b_proj.weight_scale

- EntryClass = DeepseekV2ForCausalLM
+ class DeepseekV3ForCausalLM(DeepseekV2ForCausalLM):
+     pass
+
+
+ EntryClass = [DeepseekV2ForCausalLM, DeepseekV3ForCausalLM]
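The load_weights post-processing above prepares the MLA operands: kv_b_proj is split into w_kc and w_vc, and block-quantized FP8 checkpoints are first requantized to a single per-tensor scale because bmm_fp8 only supports per-tensor scales. The toy-shape sketch below shows just the reshape/split step; the head count and dimensions are invented for illustration.

# Toy-shape sketch of the kv_b_proj split performed in load_weights above.
# The head count and dims are invented; only the reshaping logic is real.
import torch

num_heads, qk_nope_head_dim, v_head_dim, kv_lora_rank = 4, 128, 128, 512
w = torch.randn(num_heads * (qk_nope_head_dim + v_head_dim), kv_lora_rank)

w_kc, w_vc = w.unflatten(0, (-1, qk_nope_head_dim + v_head_dim)).split(
    [qk_nope_head_dim, v_head_dim], dim=1
)
w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)  # (heads, nope_dim, rank)
w_vc = w_vc.contiguous().transpose(1, 2)                  # (heads, rank, v_dim)
print(w_kc.shape, w_vc.shape)  # torch.Size([4, 128, 512]) torch.Size([4, 512, 128])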
sglang/srt/models/gemma2.py CHANGED
@@ -307,6 +307,25 @@ class Gemma2Model(nn.Module):


  class Gemma2ForCausalLM(nn.Module):
+     # BitandBytes specific attributes
+     default_bitsandbytes_target_modules = [
+         ".gate_proj.",
+         ".down_proj.",
+         ".up_proj.",
+         ".q_proj.",
+         ".k_proj.",
+         ".v_proj.",
+         ".o_proj.",
+     ]
+     bitsandbytes_stacked_params_mapping = {
+         # shard_name, weight_name, index
+         "q_proj": ("qkv_proj", 0),
+         "k_proj": ("qkv_proj", 1),
+         "v_proj": ("qkv_proj", 2),
+         "gate_proj": ("gate_up_proj", 0),
+         "up_proj": ("gate_up_proj", 1),
+     }
+
      packed_modules_mapping = {
          "qkv_proj": [
              "q_proj",
sglang/srt/models/grok.py CHANGED
@@ -26,7 +26,6 @@ from vllm.distributed import get_tensor_model_parallel_world_size
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import GeluAndMul
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
sglang/srt/models/llama.py CHANGED
@@ -325,8 +325,8 @@ class LlamaForCausalLM(nn.Module):
          self.config = config
          self.quant_config = quant_config
          self.model = LlamaModel(config, quant_config=quant_config)
-         # Llama 3.2 1B Insturct set tie_word_embeddings to True
-         # Llama 3.1 8B Insturct set tie_word_embeddings to False
+         # Llama 3.2 1B Instruct set tie_word_embeddings to True
+         # Llama 3.1 8B Instruct set tie_word_embeddings to False
          if self.config.tie_word_embeddings:
              self.lm_head = self.model.embed_tokens
          else:
@@ -27,8 +27,6 @@ from vllm.distributed import (
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope

- from sglang.srt.layers.ep_moe.layer import EPMoE
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      QKVParallelLinear,
@@ -36,6 +34,8 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -36,9 +36,9 @@ from vllm.model_executor.layers.linear import (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -29,7 +29,6 @@ from vllm.distributed import (
  from vllm.model_executor.layers.rotary_embedding import get_rope

  from sglang.srt.layers.activation import SiluAndMul
- from sglang.srt.layers.fused_moe_triton import FusedMoE
  from sglang.srt.layers.layernorm import RMSNorm
  from sglang.srt.layers.linear import (
      MergedColumnParallelLinear,
@@ -38,6 +37,7 @@ from sglang.srt.layers.linear import (
      RowParallelLinear,
  )
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
@@ -33,8 +33,8 @@ from vllm.model_executor.layers.linear import (
  )
  from vllm.model_executor.layers.rotary_embedding import get_rope

- from sglang.srt.layers.fused_moe_triton import fused_moe
  from sglang.srt.layers.logits_processor import LogitsProcessor
+ from sglang.srt.layers.moe.fused_moe_triton import fused_moe
  from sglang.srt.layers.quantization.base_config import QuantizationConfig
  from sglang.srt.layers.radix_attention import RadixAttention
  from sglang.srt.layers.vocab_parallel_embedding import (
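The remaining model files in this group change nothing but their import paths: the ep_moe and fused_moe_triton packages now live under sglang.srt.layers.moe. Code outside the package that imported the old locations needs the same one-line switch:

# Old (0.4.0.post2) import paths:
# from sglang.srt.layers.ep_moe.layer import EPMoE
# from sglang.srt.layers.fused_moe_triton import FusedMoE, fused_moe

# New (0.4.1.post1) import paths:
from sglang.srt.layers.moe.ep_moe.layer import EPMoE
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE, fused_moe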
sglang/srt/openai_api/adapter.py CHANGED
@@ -517,6 +517,7 @@ def v1_generate_request(
                  "repetition_penalty": request.repetition_penalty,
                  "regex": request.regex,
                  "json_schema": request.json_schema,
+                 "ebnf": request.ebnf,
                  "n": request.n,
                  "no_stop_trim": request.no_stop_trim,
                  "ignore_eos": request.ignore_eos,
@@ -692,6 +693,14 @@ def v1_generate_response(request, ret, tokenizer_manager, to_file=False):

  async def v1_completions(tokenizer_manager, raw_request: Request):
      request_json = await raw_request.json()
+     if "extra_body" in request_json:
+         extra = request_json["extra_body"]
+         if "ebnf" in extra:
+             request_json["ebnf"] = extra["ebnf"]
+         if "regex" in extra:
+             request_json["regex"] = extra["regex"]
+         # remove extra_body to avoid pydantic conflict
+         del request_json["extra_body"]
      all_requests = [CompletionRequest(**request_json)]
      adapted_request, request = v1_generate_request(all_requests)

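On the client side, the new ebnf field rides in through the OpenAI SDK's extra_body argument (like regex), which the handler above unpacks into top-level request fields before pydantic validation. A hedged sketch with the openai>=1.x Python client; the server URL, model name, and grammar are placeholders:

# Hedged client-side sketch: pass an EBNF grammar to a sglang server via extra_body.
# The URL, model name, and grammar below are placeholders.
import openai

client = openai.OpenAI(base_url="http://127.0.0.1:30000/v1", api_key="EMPTY")

response = client.completions.create(
    model="default",
    prompt="Answer yes or no: is the sky blue?",
    temperature=0,
    max_tokens=8,
    extra_body={"ebnf": 'root ::= "yes" | "no"'},
)
print(response.choices[0].text)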
@@ -858,6 +867,7 @@
      logprob_start_lens = []
      top_logprobs_nums = []
      modalities_list = []
+     lora_paths = []

      # NOTE: with openai API, the prompt's logprobs are always not computed

@@ -920,6 +930,7 @@
          return_logprobs.append(request.logprobs)
          logprob_start_lens.append(-1)
          top_logprobs_nums.append(request.top_logprobs or 0)
+         lora_paths.append(request.lora_path)

          sampling_params = {
              "temperature": request.temperature,
@@ -934,6 +945,7 @@
              "frequency_penalty": request.frequency_penalty,
              "repetition_penalty": request.repetition_penalty,
              "regex": request.regex,
+             "ebnf": request.ebnf,
              "n": request.n,
              "no_stop_trim": request.no_stop_trim,
              "ignore_eos": request.ignore_eos,
@@ -958,6 +970,7 @@
          logprob_start_lens = logprob_start_lens[0]
          top_logprobs_nums = top_logprobs_nums[0]
          modalities_list = modalities_list[0]
+         lora_paths = lora_paths[0]
      else:
          if isinstance(input_ids[0], str):
              prompt_kwargs = {"text": input_ids}
@@ -975,6 +988,7 @@
          return_text_in_logprobs=True,
          rid=request_ids,
          modalities=modalities_list,
+         lora_path=lora_paths,
      )

      return adapted_request, all_requests if len(all_requests) > 1 else all_requests[0]
@@ -1104,6 +1118,15 @@ def v1_chat_generate_response(request, ret, to_file=False, cache_report=False):

  async def v1_chat_completions(tokenizer_manager, raw_request: Request):
      request_json = await raw_request.json()
+     if "extra_body" in request_json:
+         extra = request_json["extra_body"]
+         # For example, if 'ebnf' is given:
+         if "ebnf" in extra:
+             request_json["ebnf"] = extra["ebnf"]
+         if "regex" in extra:
+             request_json["regex"] = extra["regex"]
+         # remove extra_body to avoid pydantic conflict
+         del request_json["extra_body"]
      all_requests = [ChatCompletionRequest(**request_json)]
      adapted_request, request = v1_chat_generate_request(all_requests, tokenizer_manager)

sglang/srt/openai_api/protocol.py CHANGED
@@ -179,6 +179,7 @@ class CompletionRequest(BaseModel):
      ignore_eos: bool = False
      skip_special_tokens: bool = True
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+     ebnf: Optional[str] = None


  class CompletionResponseChoice(BaseModel):
@@ -288,6 +289,7 @@ class ChatCompletionRequest(BaseModel):
      ignore_eos: bool = False
      skip_special_tokens: bool = True
      lora_path: Optional[Union[List[Optional[str]], Optional[str]]] = None
+     ebnf: Optional[str] = None


  class ChatMessage(BaseModel):
sglang/srt/sampling/sampling_params.py CHANGED
@@ -36,6 +36,7 @@ class SamplingParams:
          regex: Optional[str] = None,
          n: int = 1,
          json_schema: Optional[str] = None,
+         ebnf: Optional[str] = None,
          no_stop_trim: bool = False,
          ignore_eos: bool = False,
          skip_special_tokens: bool = True,
@@ -60,6 +61,7 @@
          self.regex = regex
          self.n = n
          self.json_schema = json_schema
+         self.ebnf = ebnf
          self.no_stop_trim = no_stop_trim

          # Process some special cases
@@ -111,8 +113,13 @@
                  f"min_new_tokens must be in (0, max_new_tokens({self.max_new_tokens})], got "
                  f"{self.min_new_tokens}."
              )
-         if self.regex is not None and self.json_schema is not None:
-             raise ValueError("regex and json_schema cannot be both set.")
+         grammars = [
+             self.json_schema,
+             self.regex,
+             self.ebnf,
+         ]  # since mutually exclusive, only one can be set
+         if sum(x is not None for x in grammars) > 1:
+             raise ValueError("Only one of regex, json_schema, or ebnf can be set.")

      def normalize(self, tokenizer):
          # Process stop strings
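The structured-output options are now mutually exclusive as a group: supplying more than one of json_schema, regex, or ebnf raises a ValueError during verification. The same check in isolation:

# Minimal sketch of the mutual-exclusivity check added to SamplingParams.verify().
json_schema = '{"type": "object"}'
regex = None
ebnf = None

grammars = [json_schema, regex, ebnf]  # only one may be set
if sum(x is not None for x in grammars) > 1:
    raise ValueError("Only one of regex, json_schema, or ebnf can be set.")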
sglang/srt/server.py CHANGED
@@ -245,16 +245,11 @@ async def get_weights_by_name(obj: GetWeightsByNameReqInput, request: Request):
      try:
          ret = await tokenizer_manager.get_weights_by_name(obj, request)
          if ret is None:
-             return ORJSONResponse(
-                 {"error": {"message": "Get parameter by name failed"}},
-                 status_code=HTTPStatus.BAD_REQUEST,
-             )
+             return _create_error_response("Get parameter by name failed")
          else:
              return ORJSONResponse(ret, status_code=200)
      except Exception as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         return _create_error_response(e)


  @app.api_route("/open_session", methods=["GET", "POST"])
@@ -264,9 +259,7 @@ async def open_session(obj: OpenSessionReqInput, request: Request):
          session_id = await tokenizer_manager.open_session(obj, request)
          return session_id
      except Exception as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         return _create_error_response(e)


  @app.api_route("/close_session", methods=["GET", "POST"])
@@ -276,9 +269,7 @@ async def close_session(obj: CloseSessionReqInput, request: Request):
          await tokenizer_manager.close_session(obj, request)
          return Response(status_code=200)
      except Exception as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         return _create_error_response(e)


  # fastapi implicitly converts json in the request to obj (dataclass)
@@ -311,9 +302,8 @@ async def generate_request(obj: GenerateReqInput, request: Request):
          ret = await tokenizer_manager.generate_request(obj, request).__anext__()
          return ret
      except ValueError as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         logger.error(f"Error: {e}")
+         return _create_error_response(e)


  @app.api_route("/encode", methods=["POST", "PUT"])
@@ -324,9 +314,7 @@ async def encode_request(obj: EmbeddingReqInput, request: Request):
          ret = await tokenizer_manager.generate_request(obj, request).__anext__()
          return ret
      except ValueError as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         return _create_error_response(e)


  @app.api_route("/classify", methods=["POST", "PUT"])
@@ -337,9 +325,7 @@ async def classify_request(obj: EmbeddingReqInput, request: Request):
          ret = await tokenizer_manager.generate_request(obj, request).__anext__()
          return ret
      except ValueError as e:
-         return ORJSONResponse(
-             {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
-         )
+         return _create_error_response(e)


  ##### OpenAI-compatible API endpoints #####
@@ -415,6 +401,12 @@ async def retrieve_file_content(file_id: str):
      return await v1_retrieve_file_content(file_id)


+ def _create_error_response(e):
+     return ORJSONResponse(
+         {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
+     )
+
+
  def launch_engine(
      server_args: ServerArgs,
  ):
@@ -848,12 +840,10 @@ class Engine:
              group_name=group_name,
              backend=backend,
          )
-
-         async def _init_group():
-             return await tokenizer_manager.init_weights_update_group(obj, None)
-
          loop = asyncio.get_event_loop()
-         return loop.run_until_complete(_init_group())
+         return loop.run_until_complete(
+             tokenizer_manager.init_weights_update_group(obj, None)
+         )

      def update_weights_from_distributed(self, name, dtype, shape):
          """Update weights from distributed source."""
@@ -862,22 +852,16 @@ class Engine:
              dtype=dtype,
              shape=shape,
          )
-
-         async def _update_weights():
-             return await tokenizer_manager.update_weights_from_distributed(obj, None)
-
          loop = asyncio.get_event_loop()
-         return loop.run_until_complete(_update_weights())
+         return loop.run_until_complete(
+             tokenizer_manager.update_weights_from_distributed(obj, None)
+         )

      def get_weights_by_name(self, name, truncate_size=100):
          """Get weights by parameter name."""
          obj = GetWeightsByNameReqInput(name=name, truncate_size=truncate_size)
-
-         async def _get_weights():
-             return await tokenizer_manager.get_weights_by_name(obj, None)
-
          loop = asyncio.get_event_loop()
-         return loop.run_until_complete(_get_weights())
+         return loop.run_until_complete(tokenizer_manager.get_weights_by_name(obj, None))


  class Runtime:
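server.py folds the repeated inline ORJSONResponse construction into a single _create_error_response helper and drops the throwaway async wrappers in the Engine methods, handing the tokenizer_manager coroutine straight to run_until_complete. A hedged, standalone sketch of the error-helper pattern; the FastAPI app and endpoint below are illustrative, not sglang's actual server module:

# Hedged sketch of the consolidated error-response pattern.
from http import HTTPStatus

from fastapi import FastAPI
from fastapi.responses import ORJSONResponse

app = FastAPI()


def _create_error_response(e):
    return ORJSONResponse(
        {"error": {"message": str(e)}}, status_code=HTTPStatus.BAD_REQUEST
    )


@app.post("/demo")
async def demo_endpoint():
    try:
        raise ValueError("something went wrong")
    except ValueError as e:
        return _create_error_response(e)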
sglang/srt/utils.py CHANGED
@@ -14,6 +14,7 @@
  """Common utilities."""

  import base64
+ import dataclasses
  import ipaddress
  import itertools
  import json
@@ -1238,49 +1239,37 @@ def cuda_device_count_stateless() -> int:
      return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


- def should_use_tensor_core(
-     kv_cache_dtype: torch.dtype,
-     num_attention_heads: int,
-     num_kv_heads: int,
- ) -> bool:
-     """
-     Determine whether to use tensor cores for attention computation.
-
-     Args:
-         kv_cache_dtype: Data type of the KV cache
-         num_attention_heads: Number of attention heads
-         num_kv_heads: Number of key/value heads
-
-     Returns:
-         bool: Whether to use tensor cores
-     """
-     # Try to use environment variable first
-     env_override = os.environ.get("SGLANG_FLASHINFER_USE_TENSOR_CORE")
-     if env_override is not None:
-         return env_override.lower() == "true"
-
-     # Try to use _grouped_size_compiled_for_decode_kernels if available
-     # This is for flashinfer <=0.1.6. Otherwise, there is an accuracy bug
-     try:
-         from flashinfer.decode import _grouped_size_compiled_for_decode_kernels
-
-         if not _grouped_size_compiled_for_decode_kernels(
-             num_attention_heads,
-             num_kv_heads,
-         ):
-             return True
+ def dataclass_to_string_truncated(data, max_length=2048):
+     if isinstance(data, str):
+         if len(data) > max_length:
+             half_length = max_length // 2
+             return f'"{data[:half_length]} ... {data[-half_length:]}"'
          else:
-             return False
-     except (ImportError, AttributeError):
-         pass
-
-     # Calculate GQA group size
-     gqa_group_size = num_attention_heads // num_kv_heads
-
-     # Determine based on dtype and GQA group size
-     if kv_cache_dtype in (torch.float8_e4m3fn, torch.float8_e5m2):
-         return True
-     elif kv_cache_dtype in (torch.float16, torch.half, torch.bfloat16):
-         return gqa_group_size > 4
+             return f'"{data}"'
+     elif isinstance(data, (list, tuple)):
+         if len(data) > max_length:
+             half_length = max_length // 2
+             return str(data[:half_length]) + " ... " + str(data[-half_length:])
+         else:
+             return str(data)
+     elif isinstance(data, dict):
+         return (
+             "{"
+             + ", ".join(
+                 f"{k}: {dataclass_to_string_truncated(v, max_length)}"
+                 for k, v in data.items()
+             )
+             + "}"
+         )
+     elif dataclasses.is_dataclass(data):
+         fields = dataclasses.fields(data)
+         return (
+             f"{data.__class__.__name__}("
+             + ", ".join(
+                 f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
+                 for f in fields
+             )
+             + ")"
+         )
      else:
-         return False
+         return str(data)
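The should_use_tensor_core helper is removed from utils.py, and the new dataclass_to_string_truncated helper renders strings, lists, dicts, and dataclasses for logging while truncating long payloads from the middle. A short usage sketch, assuming sglang 0.4.1 is installed so the helper can be imported from sglang.srt.utils; DemoRequest is a made-up dataclass:

# Usage sketch for the new logging helper.
import dataclasses

from sglang.srt.utils import dataclass_to_string_truncated


@dataclasses.dataclass
class DemoRequest:
    text: str
    max_new_tokens: int


req = DemoRequest(text="hello " * 2000, max_new_tokens=32)
# The 12000-character text field is shortened to its first and last 32 characters.
print(dataclass_to_string_truncated(req, max_length=64))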