sglang 0.3.5.post1__py3-none-any.whl → 0.3.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_latency.py +1 -553
  2. sglang/bench_offline_throughput.py +337 -0
  3. sglang/bench_one_batch.py +474 -0
  4. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  5. sglang/bench_serving.py +115 -31
  6. sglang/check_env.py +3 -6
  7. sglang/srt/constrained/base_grammar_backend.py +4 -3
  8. sglang/srt/constrained/outlines_backend.py +39 -26
  9. sglang/srt/constrained/xgrammar_backend.py +58 -14
  10. sglang/srt/layers/activation.py +3 -0
  11. sglang/srt/layers/attention/flashinfer_backend.py +93 -48
  12. sglang/srt/layers/attention/triton_backend.py +9 -7
  13. sglang/srt/layers/custom_op_util.py +26 -0
  14. sglang/srt/layers/fused_moe/fused_moe.py +11 -4
  15. sglang/srt/layers/fused_moe/patch.py +4 -2
  16. sglang/srt/layers/layernorm.py +4 -0
  17. sglang/srt/layers/logits_processor.py +10 -10
  18. sglang/srt/layers/sampler.py +4 -8
  19. sglang/srt/layers/torchao_utils.py +2 -0
  20. sglang/srt/managers/data_parallel_controller.py +74 -9
  21. sglang/srt/managers/detokenizer_manager.py +1 -14
  22. sglang/srt/managers/io_struct.py +27 -0
  23. sglang/srt/managers/schedule_batch.py +104 -38
  24. sglang/srt/managers/schedule_policy.py +5 -1
  25. sglang/srt/managers/scheduler.py +210 -56
  26. sglang/srt/managers/session_controller.py +62 -0
  27. sglang/srt/managers/tokenizer_manager.py +38 -0
  28. sglang/srt/managers/tp_worker.py +12 -1
  29. sglang/srt/managers/tp_worker_overlap_thread.py +49 -52
  30. sglang/srt/model_executor/cuda_graph_runner.py +43 -6
  31. sglang/srt/model_executor/forward_batch_info.py +109 -15
  32. sglang/srt/model_executor/model_runner.py +102 -43
  33. sglang/srt/model_parallel.py +98 -0
  34. sglang/srt/models/deepseek_v2.py +147 -44
  35. sglang/srt/models/gemma2.py +9 -8
  36. sglang/srt/models/llava.py +1 -1
  37. sglang/srt/models/llavavid.py +1 -1
  38. sglang/srt/models/olmo.py +3 -3
  39. sglang/srt/models/phi3_small.py +447 -0
  40. sglang/srt/models/qwen2_vl.py +13 -6
  41. sglang/srt/models/torch_native_llama.py +94 -78
  42. sglang/srt/openai_api/adapter.py +11 -4
  43. sglang/srt/openai_api/protocol.py +30 -27
  44. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  45. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  46. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  47. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  48. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  49. sglang/srt/sampling/sampling_batch_info.py +58 -57
  50. sglang/srt/sampling/sampling_params.py +3 -3
  51. sglang/srt/server.py +29 -2
  52. sglang/srt/server_args.py +97 -60
  53. sglang/srt/utils.py +103 -51
  54. sglang/test/runners.py +25 -6
  55. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  56. sglang/test/test_utils.py +33 -22
  57. sglang/version.py +1 -1
  58. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/METADATA +43 -43
  59. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/RECORD +62 -56
  60. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/WHEEL +1 -1
  61. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/LICENSE +0 -0
  62. {sglang-0.3.5.post1.dist-info → sglang-0.3.6.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -15,6 +15,7 @@ import argparse
 import asyncio
 import json
 import os
+import pickle
 import random
 import resource
 import sys
@@ -387,6 +388,24 @@ async def async_request_gserver(
     raise NotImplementedError()


+async def async_request_profile(api_url: str) -> RequestFuncOutput:
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        output = RequestFuncOutput()
+        try:
+            async with session.post(url=api_url) as response:
+                if response.status == 200:
+                    output.success = True
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    return output
+
+
 def get_model(pretrained_model_name_or_path: str) -> str:
     if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
         import huggingface_hub.constants
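The new async_request_profile helper is a thin wrapper that POSTs to a profiler-control endpoint and records whether the call succeeded. A minimal usage sketch (the host and port are illustrative placeholders for a locally running sglang server):

    import asyncio

    # Start the server-side profiler; adjust the URL to your deployment.
    result = asyncio.run(
        async_request_profile(api_url="http://127.0.0.1:30000/start_profile")
    )
    print(result.success, result.error)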
@@ -421,6 +440,37 @@ def get_tokenizer(
     )


+def get_dataset(args, tokenizer):
+    if args.dataset_name == "sharegpt":
+        input_requests = sample_sharegpt_requests(
+            dataset_path=args.dataset_path,
+            num_requests=args.num_prompts,
+            tokenizer=tokenizer,
+            fixed_output_len=args.sharegpt_output_len,
+        )
+    elif args.dataset_name == "random":
+        input_requests = sample_random_requests(
+            input_len=args.random_input_len,
+            output_len=args.random_output_len,
+            num_prompts=args.num_prompts,
+            range_ratio=args.random_range_ratio,
+            tokenizer=tokenizer,
+            dataset_path=args.dataset_path,
+        )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
+    else:
+        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    return input_requests
+
+
 ASYNC_REQUEST_FUNCS = {
     "sglang": async_request_sglang_generate,
     "sglang-native": async_request_sglang_generate,
@@ -443,6 +493,8 @@ class BenchmarkMetrics:
     input_throughput: float
     output_throughput: float
     output_throughput_retokenized: float
+    total_throughput: float
+    total_throughput_retokenized: float
     mean_ttft_ms: float
     median_ttft_ms: float
     std_ttft_ms: float
@@ -590,7 +642,6 @@ def sample_random_requests(
         (data["conversations"][0]["value"], data["conversations"][1]["value"])
         for data in dataset
     ]
-
     # Shuffle the dataset.
     random.shuffle(dataset)

@@ -650,6 +701,11 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
 ) -> List[Tuple[str, int, int]]:
+    if args.generated_input_path and os.path.exists(args.generated_input_path):
+        print(f"\nloading generated input data from {args.generated_input_path}")
+        with open(args.generated_input_path, "rb") as f:
+            return pickle.load(f)
+
     """Generate benchmark requests with shared system prompts using random tokens."""
     # Generate system prompts for each group
     system_prompts = []
@@ -663,6 +719,9 @@
         question = gen_prompt(tokenizer, question_len)
         questions.append(question)

+    # Shuffle questions
+    random.shuffle(questions)
+
     # Combine system prompts with questions
     input_requests = []
     total_input_tokens = 0
@@ -691,6 +750,11 @@
     print(
         f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
     )
+    if args.generated_input_save_path:
+        print(f"Saving generated input data to {args.generated_input_save_path}")
+        os.makedirs(os.path.dirname(args.generated_input_save_path), exist_ok=True)
+        with open(args.generated_input_save_path, "wb") as f:
+            pickle.dump(input_requests, f)

     return input_requests

@@ -764,6 +828,9 @@ def calculate_metrics(
         input_throughput=total_input / dur_s,
         output_throughput=sum(output_lens) / dur_s,
         output_throughput_retokenized=sum(retokenized_output_lens) / dur_s,
+        total_throughput=(total_input + sum(output_lens)) / dur_s,
+        total_throughput_retokenized=(total_input + sum(retokenized_output_lens))
+        / dur_s,
         mean_ttft_ms=np.mean(ttfts or 0)
         * 1000,  # ttfts is empty if streaming is not supported by backend
         median_ttft_ms=np.median(ttfts or 0) * 1000,
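The two new fields fold prompt and completion tokens into a single rate. For example, if total_input is 4096 prompt tokens, sum(output_lens) is 1024 generated tokens, and dur_s is 2.0 seconds, then total_throughput = (4096 + 1024) / 2.0 = 2560 tok/s.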
@@ -787,12 +854,14 @@
 async def benchmark(
     backend: str,
     api_url: str,
+    base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
     disable_tqdm: bool,
     extra_request_body: Dict[str, Any],
+    profile: bool,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -820,6 +889,14 @@ async def benchmark(

     time.sleep(1.5)

+    if profile:
+        print("Starting profiler...")
+        profile_output = await async_request_profile(
+            api_url=base_url + "/start_profile"
+        )
+        if profile_output.success:
+            print("Profiler started")
+
     pbar = None if disable_tqdm else tqdm(total=len(input_requests))

     benchmark_start_time = time.perf_counter()
@@ -841,6 +918,12 @@
         )
     outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)

+    if profile:
+        print("Stopping profiler...")
+        profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
+        if profile_output.success:
+            print("Profiler stopped")
+
     if pbar is not None:
         pbar.close()

@@ -881,6 +964,11 @@
             "Output token throughput (tok/s):", metrics.output_throughput
         )
     )
+    print(
+        "{:<40} {:<10.2f}".format(
+            "Total token throughput (tok/s):", metrics.total_throughput
+        )
+    )
     print("{s:{c}^{n}}".format(s="End-to-End Latency", n=50, c="-"))
     print(
         "{:<40} {:<10.2f}".format("Mean E2E Latency (ms):", metrics.mean_e2e_latency_ms)
@@ -1060,6 +1148,9 @@ def run_benchmark(args_: argparse.Namespace):
         if args.base_url
         else f"http://{args.host}:{args.port}/v1/models/model:predict"
     )
+    base_url = (
+        f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
+    )

     # Get model name
     if args.model is None:
@@ -1098,47 +1189,21 @@

     tokenizer = get_tokenizer(tokenizer_id)

-    if args.dataset_name == "sharegpt":
-        assert args.random_input_len is None and args.random_output_len is None
-        input_requests = sample_sharegpt_requests(
-            dataset_path=args.dataset_path,
-            num_requests=args.num_prompts,
-            tokenizer=tokenizer,
-            fixed_output_len=args.sharegpt_output_len,
-        )
-    elif args.dataset_name == "random":
-        assert args.random_input_len is not None and args.random_output_len is not None
-        input_requests = sample_random_requests(
-            input_len=args.random_input_len,
-            output_len=args.random_output_len,
-            num_prompts=args.num_prompts,
-            range_ratio=args.random_range_ratio,
-            tokenizer=tokenizer,
-            dataset_path=args.dataset_path,
-        )
-    elif args.dataset_name == "generated-shared-prefix":
-        input_requests = sample_generated_shared_prefix_requests(
-            num_groups=args.gen_num_groups,
-            prompts_per_group=args.gen_prompts_per_group,
-            system_prompt_len=args.gen_system_prompt_len,
-            question_len=args.gen_question_len,
-            output_len=args.gen_output_len,
-            tokenizer=tokenizer,
-        )
-    else:
-        raise ValueError(f"Unknown dataset: {args.dataset_name}")
+    input_requests = get_dataset(args, tokenizer)

     if not args.multi:
         return asyncio.run(
             benchmark(
                 backend=backend,
                 api_url=api_url,
+                base_url=base_url,
                 model_id=model_id,
                 tokenizer=tokenizer,
                 input_requests=input_requests,
                 request_rate=args.request_rate,
                 disable_tqdm=args.disable_tqdm,
                 extra_request_body=extra_request_body,
+                profile=args.profile,
             )
         )
     else:
1150
1215
  benchmark(
1151
1216
  backend=backend,
1152
1217
  api_url=api_url,
1218
+ base_url=base_url,
1153
1219
  model_id=model_id,
1154
1220
  tokenizer=tokenizer,
1155
1221
  input_requests=input_requests,
1156
1222
  request_rate=rate,
1157
1223
  disable_tqdm=args.disable_tqdm,
1158
1224
  extra_request_body=extra_request_body,
1225
+ profile=args.profile,
1159
1226
  )
1160
1227
  )
1161
1228
 
@@ -1229,10 +1296,12 @@ if __name__ == "__main__":
1229
1296
  parser.add_argument(
1230
1297
  "--random-input-len",
1231
1298
  type=int,
1299
+ default=1024,
1232
1300
  help="Number of input tokens per request, used only for random dataset.",
1233
1301
  )
1234
1302
  parser.add_argument(
1235
1303
  "--random-output-len",
1304
+ default=1024,
1236
1305
  type=int,
1237
1306
  help="Number of output tokens per request, used only for random dataset.",
1238
1307
  )
@@ -1317,6 +1386,21 @@ if __name__ == "__main__":
1317
1386
  default=256,
1318
1387
  help="Target length in tokens for outputs in generated-shared-prefix dataset",
1319
1388
  )
1320
-
1389
+ parser.add_argument(
1390
+ "--generated-input-save-path",
1391
+ type=str,
1392
+ help="Path to save generated input data",
1393
+ )
1394
+ parser.add_argument(
1395
+ "--generated-input-path",
1396
+ type=str,
1397
+ help="Path to load previously generated input data",
1398
+ )
1399
+ parser.add_argument(
1400
+ "--profile",
1401
+ action="store_true",
1402
+ help="Use Torch Profiler. The endpoint must be launched with "
1403
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
1404
+ )
1321
1405
  args = parser.parse_args()
1322
1406
  run_benchmark(args)
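Taken together, the bench_serving changes add dataset caching and on-demand profiling around a benchmark run. A sketch of how the new flags might be combined (paths and port are illustrative; per the help text above, --profile requires the server to have been launched with SGLANG_TORCH_PROFILER_DIR set):

    # Server side (illustrative):
    #   SGLANG_TORCH_PROFILER_DIR=/tmp/sglang_traces python3 -m sglang.launch_server --model-path <model>

    # Generate the shared-prefix dataset once, cache it, and profile the run:
    python3 -m sglang.bench_serving --backend sglang \
        --dataset-name generated-shared-prefix \
        --generated-input-save-path /tmp/gen_prefix.pkl \
        --profile

    # Later runs can reuse the cached dataset:
    python3 -m sglang.bench_serving --backend sglang \
        --dataset-name generated-shared-prefix \
        --generated-input-path /tmp/gen_prefix.pkl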
sglang/check_env.py CHANGED
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
     "flashinfer",
     "triton",
     "transformers",
-    "requests",
-    "tqdm",
+    "torchao",
     "numpy",
     "aiohttp",
     "fastapi",
     "hf_transfer",
     "huggingface_hub",
     "interegular",
-    "packaging",
-    "PIL",
     "psutil",
     "pydantic",
+    "multipart",
+    "zmq",
     "uvicorn",
     "uvloop",
-    "zmq",
     "vllm",
     "outlines",
-    "multipart",
     "openai",
     "tiktoken",
     "anthropic",
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

-"""The baseclass of backends for grammar-guided constrained decoding."""
+"""The baseclass of a backend for grammar-guided constrained decoding."""

 from concurrent.futures import Future, ThreadPoolExecutor
 from dataclasses import dataclass
@@ -52,7 +52,7 @@ class BaseGrammarBackend:
         else:
             entry.value = self.init_value_impl(key)
             entry.event.set()
-        return entry.value.copy()
+        return entry.value.copy() if entry.value else None

     def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
         raise NotImplementedError()
@@ -62,7 +62,8 @@
             entry = self.cache.get(key)
             if not entry or not entry.event.is_set():
                 return None
-            return self.cache[key].value.copy()
+            val = self.cache[key].value
+            return val.copy() if val else None

     def get_future_value(self, key: Tuple[str, str]) -> Future:
         return self.executor.submit(self.init_value, key)
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -19,9 +19,12 @@ import json
 import logging
 from typing import Dict, List, Optional, Tuple, Union

+import interegular
 import torch
 from outlines.fsm.guide import RegexGuide
+from outlines.fsm.json_schema import build_regex_from_schema
 from outlines.models.transformers import TransformerTokenizer
+from pydantic import BaseModel

 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
@@ -32,26 +35,6 @@ from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
 logger = logging.getLogger(__name__)


-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-    from pydantic import BaseModel
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-
 class OutlinesGrammar(BaseGrammarObject):
     def __init__(
         self,
@@ -98,9 +81,22 @@ class OutlinesGrammar(BaseGrammarObject):
     ):
         self.state = next_state

-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        tokens = torch.tensor(
+            self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
+        ).to(vocab_mask.device, non_blocking=True)
+        vocab_mask = vocab_mask[idx]
         vocab_mask.fill_(1)
-        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+        vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
+        logits.masked_fill_(vocab_mask, float("-inf"))

     def copy(self):
         return OutlinesGrammar(self.guide, self.jump_forward_map)
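The vocab-mask interface is now split into allocate / fill / apply steps, so one boolean mask tensor can be shared across a batch of grammars and applied to the logits in a single masked_fill_. A minimal sketch of the intended call pattern (the per-request grammar list and the logits tensor are assumed to come from the scheduler):

    # grammars: one OutlinesGrammar per request; logits: [batch_size, vocab_size]
    vocab_mask = grammars[0].allocate_vocab_mask(
        vocab_size=logits.shape[1], batch_size=logits.shape[0], device=logits.device
    )
    for i, grammar in enumerate(grammars):
        grammar.fill_vocab_mask(vocab_mask, i)  # mark disallowed tokens for row i
    OutlinesGrammar.apply_vocab_mask(logits, vocab_mask)  # set masked logits to -inf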
@@ -147,19 +143,36 @@ class OutlinesGrammarBackend(BaseGrammarBackend):
                     key_string,
                     whitespace_pattern=self.whitespace_pattern,
                 )
-            except NotImplementedError as e:
+            except (NotImplementedError, json.decoder.JSONDecodeError) as e:
                 logger.warning(
-                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
                 )
-                return None, key_string
+                return None
         elif key_type == "regex":
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")

-        guide = RegexGuide(regex, self.outlines_tokenizer)
+        try:
+            guide = RegexGuide(regex, self.outlines_tokenizer)
+        except interegular.patterns.InvalidSyntax as e:
+            logger.warning(f"skip invalid regex schema: {regex=}, {e=}")
+            return None
+
         if self.allow_jump_forward:
             jump_forward_map = OutlinesJumpForwardMap(regex)
         else:
             jump_forward_map = None
         return OutlinesGrammar(guide, jump_forward_map)
+
+
+def build_regex_from_object(
+    object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+):
+    if isinstance(object, type(BaseModel)):
+        schema = json.dumps(object.model_json_schema())
+    elif isinstance(object, Dict):
+        schema = json.dumps(object)
+    else:
+        schema = object
+    return build_regex_from_schema(schema, whitespace_pattern)
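build_regex_from_object, previously defined only inside an ImportError fallback, is now a plain module-level helper on top of outlines' build_regex_from_schema. A short usage sketch (the pydantic model is illustrative):

    from pydantic import BaseModel

    class Answer(BaseModel):
        name: str
        score: int

    regex_a = build_regex_from_object(Answer)                # pydantic model class
    regex_b = build_regex_from_object({"type": "object"})    # plain dict schema
    regex_c = build_regex_from_object('{"type": "string"}')  # raw JSON schema string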
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -15,16 +15,34 @@ limitations under the License.

 """Constrained decoding with xgrammar backend."""

+import logging
 from typing import List, Tuple

 import torch
-from xgrammar import CachedGrammarCompiler, CompiledGrammar, GrammarMatcher
+
+try:
+    from xgrammar import (
+        CachedGrammarCompiler,
+        CompiledGrammar,
+        GrammarMatcher,
+        TokenizerInfo,
+    )
+
+    import_error = None
+except ImportError as e:
+    CachedGrammarCompiler = CompiledGrammar = GrammarMatcher = TokenizerInfo = (
+        ImportError
+    )
+    import_error = e

 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )

+logger = logging.getLogger(__name__)
+
+
 MAX_ROLLBACK_TOKENS = 10

@@ -67,19 +85,23 @@ class XGrammarGrammar(BaseGrammarObject):
         for i in range(k, len(new_output_ids)):
             assert self.matcher.accept_token(new_output_ids[i])

-    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
-        # Note that this bitmask is a bitset, not bool
-        bitmask = self.matcher.get_next_token_bitmask()
-        # Mask the tokens that are not allowed
-        vocab_mask[
-            self.matcher.get_rejected_tokens_from_bitmask(bitmask, self.vocab_size)
-        ] = 1
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.matcher.allocate_token_bitmask(vocab_size, batch_size)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        self.matcher.fill_next_token_bitmask(vocab_mask, idx)
+
+    @staticmethod
+    def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
+        GrammarMatcher.apply_token_bitmask_inplace(logits, vocab_mask)

     def copy(self):
         matcher = GrammarMatcher(
             self.ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, self.ctx)

@@ -91,24 +113,46 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
         vocab_size: int,
     ):
         super().__init__()
-        self.grammar_cache = CachedGrammarCompiler(tokenizer_or_vocab=tokenizer)
+
+        if import_error:
+            logger.warning(
+                f"Ignore import error for the grammar backend: {import_error}"
+            )
+            self.grammar_cache = None
+            return
+
+        tokenizer_info = TokenizerInfo.from_huggingface(tokenizer)
+        self.grammar_cache = CachedGrammarCompiler(tokenizer_info=tokenizer_info)
         self.vocab_size = vocab_size

     def init_value_impl(self, key: Tuple[str, str]) -> XGrammarGrammar:
+        if import_error:
+            raise import_error
+
         key_type, key_string = key
         if key_type == "json":
-            ctx = self.grammar_cache.get_compiled_grammar_for_json_schema(key_string)
+            try:
+                ctx = self.grammar_cache.compile_json_schema_grammar(schema=key_string)
+            except RuntimeError as e:
+                logging.warning(
+                    f"Skip invalid json_schema: json_schema={key_string}, {e=}"
+                )
+                return None
         elif key_type == "regex":
-            raise ValueError("regex hasn't been supported by xgrammar yet")
+            logger.warning(
+                "regex hasn't been supported by xgrammar yet. This is skipped."
+            )
+            return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")

         matcher = GrammarMatcher(
             ctx,
             max_rollback_tokens=MAX_ROLLBACK_TOKENS,
-            mask_vocab_size=self.vocab_size,
+            vocab_size=self.vocab_size,
         )
         return XGrammarGrammar(matcher, self.vocab_size, ctx)

     def reset(self):
-        self.grammar_cache.clear()
+        if self.grammar_cache:
+            self.grammar_cache.clear()
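Because the xgrammar import is now wrapped in try/except, the backend degrades gracefully when the package is missing: constructing the backend only logs a warning, grammar_cache stays None, reset() becomes a no-op, and init_value_impl re-raises the stored import error. A short sketch of checking for that state (the tokenizer and vocab size are placeholders):

    backend = XGrammarGrammarBackend(tokenizer, vocab_size=32000)
    if backend.grammar_cache is None:
        # xgrammar is not installed; grammar requests will raise the saved import_error
        print("xgrammar unavailable, use the outlines grammar backend instead")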
sglang/srt/layers/activation.py CHANGED
@@ -32,12 +32,14 @@ from vllm.distributed import (
 )
 from vllm.model_executor.custom_op import CustomOp

+from sglang.srt.layers.custom_op_util import register_custom_op
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.utils import set_weight_attrs

 logger = logging.getLogger(__name__)


+@register_custom_op("sglang_silu_and_mul")
 class SiluAndMul(CustomOp):
     def forward_native(self, x: torch.Tensor) -> torch.Tensor:
         d = x.shape[-1] // 2
@@ -51,6 +53,7 @@ class SiluAndMul(CustomOp):
         return out


+@register_custom_op("sglang_gelu_and_mul")
 class GeluAndMul(CustomOp):
     def __init__(self, approximate="tanh"):
         super().__init__()
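Both activations are now registered through register_custom_op from the new sglang/srt/layers/custom_op_util.py, whose body is not part of this excerpt. Purely as a hypothetical sketch (not the actual implementation), such a shim only needs to forward to vllm's CustomOp registry when it exists and otherwise leave the class unchanged:

    # Hypothetical sketch only -- the real custom_op_util.py is not shown in this diff.
    from vllm.model_executor.custom_op import CustomOp

    def register_custom_op(op_name: str):
        def decorator(op_cls):
            if hasattr(CustomOp, "register"):  # newer vllm exposes a name registry
                return CustomOp.register(op_name)(op_cls)
            return op_cls  # older vllm: registration is a no-op
        return decorator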