sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (118)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_latency.py +1 -553
  4. sglang/bench_offline_throughput.py +48 -20
  5. sglang/bench_one_batch.py +472 -0
  6. sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
  7. sglang/bench_serving.py +125 -6
  8. sglang/check_env.py +3 -6
  9. sglang/lang/backend/base_backend.py +1 -1
  10. sglang/lang/backend/runtime_endpoint.py +2 -2
  11. sglang/srt/configs/model_config.py +13 -14
  12. sglang/srt/constrained/__init__.py +13 -14
  13. sglang/srt/constrained/base_grammar_backend.py +13 -15
  14. sglang/srt/constrained/outlines_backend.py +28 -17
  15. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  16. sglang/srt/constrained/xgrammar_backend.py +47 -58
  17. sglang/srt/conversation.py +13 -15
  18. sglang/srt/hf_transformers_utils.py +13 -15
  19. sglang/srt/layers/activation.py +16 -13
  20. sglang/srt/layers/attention/flashinfer_backend.py +106 -54
  21. sglang/srt/layers/attention/triton_backend.py +9 -7
  22. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  23. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  24. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  25. sglang/srt/layers/custom_op_util.py +25 -0
  26. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  27. sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
  28. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  29. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  30. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  31. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  32. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  33. sglang/srt/layers/layernorm.py +17 -15
  34. sglang/srt/layers/logits_processor.py +23 -25
  35. sglang/srt/layers/quantization/__init__.py +77 -17
  36. sglang/srt/layers/radix_attention.py +13 -15
  37. sglang/srt/layers/rotary_embedding.py +13 -13
  38. sglang/srt/layers/sampler.py +4 -8
  39. sglang/srt/layers/torchao_utils.py +2 -0
  40. sglang/srt/lora/lora.py +13 -14
  41. sglang/srt/lora/lora_config.py +13 -14
  42. sglang/srt/lora/lora_manager.py +22 -24
  43. sglang/srt/managers/data_parallel_controller.py +98 -27
  44. sglang/srt/managers/detokenizer_manager.py +13 -15
  45. sglang/srt/managers/io_struct.py +63 -21
  46. sglang/srt/managers/schedule_batch.py +154 -59
  47. sglang/srt/managers/schedule_policy.py +18 -16
  48. sglang/srt/managers/scheduler.py +278 -109
  49. sglang/srt/managers/session_controller.py +61 -0
  50. sglang/srt/managers/tokenizer_manager.py +63 -18
  51. sglang/srt/managers/tp_worker.py +25 -16
  52. sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
  53. sglang/srt/metrics/collector.py +13 -15
  54. sglang/srt/metrics/func_timer.py +13 -15
  55. sglang/srt/mm_utils.py +13 -14
  56. sglang/srt/model_executor/cuda_graph_runner.py +63 -25
  57. sglang/srt/model_executor/forward_batch_info.py +128 -32
  58. sglang/srt/model_executor/model_runner.py +132 -64
  59. sglang/srt/model_parallel.py +98 -0
  60. sglang/srt/models/chatglm.py +15 -16
  61. sglang/srt/models/commandr.py +15 -16
  62. sglang/srt/models/dbrx.py +15 -16
  63. sglang/srt/models/deepseek.py +15 -15
  64. sglang/srt/models/deepseek_v2.py +162 -59
  65. sglang/srt/models/exaone.py +14 -15
  66. sglang/srt/models/gemma.py +14 -14
  67. sglang/srt/models/gemma2.py +31 -25
  68. sglang/srt/models/gemma2_reward.py +13 -14
  69. sglang/srt/models/gpt_bigcode.py +14 -14
  70. sglang/srt/models/grok.py +15 -15
  71. sglang/srt/models/internlm2.py +13 -15
  72. sglang/srt/models/internlm2_reward.py +13 -14
  73. sglang/srt/models/llama.py +21 -21
  74. sglang/srt/models/llama_classification.py +13 -14
  75. sglang/srt/models/llama_reward.py +13 -14
  76. sglang/srt/models/llava.py +14 -16
  77. sglang/srt/models/llavavid.py +14 -16
  78. sglang/srt/models/minicpm.py +13 -15
  79. sglang/srt/models/minicpm3.py +13 -15
  80. sglang/srt/models/mistral.py +13 -15
  81. sglang/srt/models/mixtral.py +15 -15
  82. sglang/srt/models/mixtral_quant.py +14 -14
  83. sglang/srt/models/olmo.py +22 -20
  84. sglang/srt/models/olmoe.py +23 -20
  85. sglang/srt/models/phi3_small.py +447 -0
  86. sglang/srt/models/qwen.py +14 -14
  87. sglang/srt/models/qwen2.py +22 -19
  88. sglang/srt/models/qwen2_moe.py +17 -18
  89. sglang/srt/models/qwen2_vl.py +13 -6
  90. sglang/srt/models/stablelm.py +18 -16
  91. sglang/srt/models/torch_native_llama.py +107 -93
  92. sglang/srt/models/xverse.py +13 -14
  93. sglang/srt/models/xverse_moe.py +15 -16
  94. sglang/srt/models/yivl.py +13 -15
  95. sglang/srt/openai_api/adapter.py +19 -17
  96. sglang/srt/openai_api/protocol.py +14 -16
  97. sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
  98. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
  99. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
  100. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
  101. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
  102. sglang/srt/sampling/sampling_batch_info.py +61 -57
  103. sglang/srt/sampling/sampling_params.py +14 -16
  104. sglang/srt/server.py +86 -35
  105. sglang/srt/server_args.py +96 -80
  106. sglang/srt/utils.py +266 -68
  107. sglang/test/few_shot_gsm8k.py +8 -4
  108. sglang/test/runners.py +38 -20
  109. sglang/test/srt/sampling/penaltylib/utils.py +23 -21
  110. sglang/test/test_utils.py +31 -20
  111. sglang/version.py +1 -1
  112. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  113. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
  114. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  115. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
  116. sglang/srt/layers/fused_moe/__init__.py +0 -1
  117. sglang-0.3.5.post2.dist-info/RECORD +0 -156
  118. {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -15,6 +15,7 @@ import argparse
15
15
  import asyncio
16
16
  import json
17
17
  import os
18
+ import pickle
18
19
  import random
19
20
  import resource
20
21
  import sys
@@ -24,6 +25,7 @@ import warnings
24
25
  from argparse import ArgumentParser
25
26
  from dataclasses import dataclass, field
26
27
  from datetime import datetime
28
+ from pathlib import Path
27
29
  from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
28
30
 
29
31
  import aiohttp
@@ -387,8 +389,26 @@ async def async_request_gserver(
387
389
  raise NotImplementedError()
388
390
 
389
391
 
392
+ async def async_request_profile(api_url: str) -> RequestFuncOutput:
393
+ async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
394
+ output = RequestFuncOutput()
395
+ try:
396
+ async with session.post(url=api_url) as response:
397
+ if response.status == 200:
398
+ output.success = True
399
+ else:
400
+ output.error = response.reason or ""
401
+ output.success = False
402
+ except Exception:
403
+ output.success = False
404
+ exc_info = sys.exc_info()
405
+ output.error = "".join(traceback.format_exception(*exc_info))
406
+
407
+ return output
408
+
409
+
390
410
  def get_model(pretrained_model_name_or_path: str) -> str:
391
- if os.getenv("SGLANG_USE_MODELSCOPE", "False").lower() == "true":
411
+ if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() == "true":
392
412
  import huggingface_hub.constants
393
413
  from modelscope import snapshot_download
394
414
 
@@ -674,6 +694,19 @@ def gen_prompt(tokenizer, token_num):
674
694
  return tokenizer.decode(selected_tokens)
675
695
 
676
696
 
697
+ def get_gen_prefix_cache_path(args, tokenizer):
698
+ """Create cache directory under ~/.cache/sglang/benchmark"""
699
+ cache_dir = Path.home() / ".cache" / "sglang" / "benchmark"
700
+
701
+ # Create a unique cache filename based on the generation parameters
702
+ cache_key = (
703
+ f"gen_prefix_{args.gen_num_groups}_{args.gen_prompts_per_group}_"
704
+ f"{args.gen_system_prompt_len}_{args.gen_question_len}_{args.gen_output_len}_"
705
+ f"{tokenizer.__class__.__name__}.pkl"
706
+ )
707
+ return cache_dir / cache_key
708
+
709
+
677
710
  def sample_generated_shared_prefix_requests(
678
711
  num_groups: int,
679
712
  prompts_per_group: int,
@@ -682,7 +715,17 @@ def sample_generated_shared_prefix_requests(
682
715
  output_len: int,
683
716
  tokenizer: PreTrainedTokenizerBase,
684
717
  ) -> List[Tuple[str, int, int]]:
685
- """Generate benchmark requests with shared system prompts using random tokens."""
718
+ """Generate benchmark requests with shared system prompts using random tokens and caching."""
719
+ cache_path = get_gen_prefix_cache_path(args, tokenizer)
720
+
721
+ # Try to load from cache first
722
+ if cache_path.exists():
723
+ print(f"\nLoading cached generated input data from {cache_path}")
724
+ with open(cache_path, "rb") as f:
725
+ return pickle.load(f)
726
+
727
+ print("\nGenerating new input data...")
728
+
686
729
  # Generate system prompts for each group
687
730
  system_prompts = []
688
731
  for _ in range(num_groups):
@@ -700,9 +743,11 @@ def sample_generated_shared_prefix_requests(
700
743
  total_input_tokens = 0
701
744
  total_output_tokens = 0
702
745
 
703
- for group_idx in range(num_groups):
746
+ for group_idx in tqdm(range(num_groups), desc="Generating system prompt"):
704
747
  system_prompt = system_prompts[group_idx]
705
- for prompt_idx in range(prompts_per_group):
748
+ for prompt_idx in tqdm(
749
+ range(prompts_per_group), desc="Generating questions", leave=False
750
+ ):
706
751
  question = questions[group_idx * prompts_per_group + prompt_idx]
707
752
  full_prompt = f"{system_prompt}\n\n{question}"
708
753
  prompt_len = len(tokenizer.encode(full_prompt))
@@ -711,6 +756,10 @@ def sample_generated_shared_prefix_requests(
711
756
  total_input_tokens += prompt_len
712
757
  total_output_tokens += output_len
713
758
 
759
+ # Shuffle questions
760
+ random.shuffle(input_requests)
761
+
762
+ # Print statistics
714
763
  print(f"\nGenerated shared prefix dataset statistics:")
715
764
  print(f"Number of groups: {num_groups}")
716
765
  print(f"Prompts per group: {prompts_per_group}")
@@ -724,6 +773,12 @@ def sample_generated_shared_prefix_requests(
724
773
  f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
725
774
  )
726
775
 
776
+ # Save to cache
777
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
778
+ print(f"Caching generated input data to {cache_path}")
779
+ with open(cache_path, "wb") as f:
780
+ pickle.dump(input_requests, f)
781
+
727
782
  return input_requests
728
783
 
729
784
 
@@ -822,18 +877,30 @@ def calculate_metrics(
822
877
  async def benchmark(
823
878
  backend: str,
824
879
  api_url: str,
880
+ base_url: str,
825
881
  model_id: str,
826
882
  tokenizer: PreTrainedTokenizerBase,
827
883
  input_requests: List[Tuple[str, int, int]],
828
884
  request_rate: float,
885
+ max_concurrency: Optional[int],
829
886
  disable_tqdm: bool,
830
887
  extra_request_body: Dict[str, Any],
888
+ profile: bool,
831
889
  ):
832
890
  if backend in ASYNC_REQUEST_FUNCS:
833
891
  request_func = ASYNC_REQUEST_FUNCS[backend]
834
892
  else:
835
893
  raise ValueError(f"Unknown backend: {backend}")
836
894
 
895
+ # From https://github.com/vllm-project/vllm/pull/9390
896
+ semaphore = asyncio.Semaphore(max_concurrency) if max_concurrency else None
897
+
898
+ async def limited_request_func(request_func_input, pbar):
899
+ if semaphore is None:
900
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
901
+ async with semaphore:
902
+ return await request_func(request_func_input=request_func_input, pbar=pbar)
903
+
837
904
  print("Starting initial single prompt test run...")
838
905
  test_prompt, test_prompt_len, test_output_len = input_requests[0]
839
906
  test_input = RequestFuncInput(
@@ -855,6 +922,14 @@ async def benchmark(
855
922
 
856
923
  time.sleep(1.5)
857
924
 
925
+ if profile:
926
+ print("Starting profiler...")
927
+ profile_output = await async_request_profile(
928
+ api_url=base_url + "/start_profile"
929
+ )
930
+ if profile_output.success:
931
+ print("Profiler started")
932
+
858
933
  pbar = None if disable_tqdm else tqdm(total=len(input_requests))
859
934
 
860
935
  benchmark_start_time = time.perf_counter()
@@ -871,11 +946,17 @@ async def benchmark(
871
946
  )
872
947
  tasks.append(
873
948
  asyncio.create_task(
874
- request_func(request_func_input=request_func_input, pbar=pbar)
949
+ limited_request_func(request_func_input=request_func_input, pbar=pbar)
875
950
  )
876
951
  )
877
952
  outputs: List[RequestFuncOutput] = await asyncio.gather(*tasks)
878
953
 
954
+ if profile:
955
+ print("Stopping profiler...")
956
+ profile_output = await async_request_profile(api_url=base_url + "/stop_profile")
957
+ if profile_output.success:
958
+ print("Profiler stopped")
959
+
879
960
  if pbar is not None:
880
961
  pbar.close()
881
962
 
@@ -892,6 +973,12 @@ async def benchmark(
892
973
  print("\n{s:{c}^{n}}".format(s=" Serving Benchmark Result ", n=50, c="="))
893
974
  print("{:<40} {:<10}".format("Backend:", backend))
894
975
  print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
976
+ print(
977
+ "{:<40} {:<10}".format(
978
+ "Max reqeuest concurrency:",
979
+ max_concurrency if max_concurrency else "not set",
980
+ )
981
+ )
895
982
  print("{:<40} {:<10}".format("Successful requests:", metrics.completed))
896
983
  print("{:<40} {:<10.2f}".format("Benchmark duration (s):", benchmark_duration))
897
984
  print("{:<40} {:<10}".format("Total input tokens:", metrics.total_input))
@@ -955,6 +1042,7 @@ async def benchmark(
955
1042
  "backend": args.backend,
956
1043
  "dataset_name": args.dataset_name,
957
1044
  "request_rate": request_rate,
1045
+ "max_concurrency": max_concurrency,
958
1046
  "total_input_tokens": metrics.total_input,
959
1047
  "total_output_tokens": metrics.total_output,
960
1048
  "total_output_tokens_retokenized": metrics.total_output_retokenized,
@@ -1042,6 +1130,10 @@ def run_benchmark(args_: argparse.Namespace):
1042
1130
  global args
1043
1131
  args = args_
1044
1132
 
1133
+ # Set default value for max_concurrency if not present
1134
+ if not hasattr(args, "max_concurrency"):
1135
+ args.max_concurrency = None
1136
+
1045
1137
  # Set global environments
1046
1138
  set_ulimit()
1047
1139
  random.seed(args.seed)
@@ -1100,6 +1192,9 @@ def run_benchmark(args_: argparse.Namespace):
1100
1192
  if args.base_url
1101
1193
  else f"http://{args.host}:{args.port}/v1/models/model:predict"
1102
1194
  )
1195
+ base_url = (
1196
+ f"http://{args.host}:{args.port}" if args.base_url is None else args.base_url
1197
+ )
1103
1198
 
1104
1199
  # Get model name
1105
1200
  if args.model is None:
@@ -1145,12 +1240,15 @@ def run_benchmark(args_: argparse.Namespace):
1145
1240
  benchmark(
1146
1241
  backend=backend,
1147
1242
  api_url=api_url,
1243
+ base_url=base_url,
1148
1244
  model_id=model_id,
1149
1245
  tokenizer=tokenizer,
1150
1246
  input_requests=input_requests,
1151
1247
  request_rate=args.request_rate,
1248
+ max_concurrency=args.max_concurrency,
1152
1249
  disable_tqdm=args.disable_tqdm,
1153
1250
  extra_request_body=extra_request_body,
1251
+ profile=args.profile,
1154
1252
  )
1155
1253
  )
1156
1254
  else:
@@ -1162,12 +1260,15 @@ def run_benchmark(args_: argparse.Namespace):
1162
1260
  benchmark(
1163
1261
  backend=backend,
1164
1262
  api_url=api_url,
1263
+ base_url=base_url,
1165
1264
  model_id=model_id,
1166
1265
  tokenizer=tokenizer,
1167
1266
  input_requests=input_requests,
1168
1267
  request_rate=rate,
1268
+ max_concurrency=args.max_concurrency,
1169
1269
  disable_tqdm=args.disable_tqdm,
1170
1270
  extra_request_body=extra_request_body,
1271
+ profile=args.profile,
1171
1272
  )
1172
1273
  )
1173
1274
 
@@ -1264,6 +1365,19 @@ if __name__ == "__main__":
1264
1365
  help="Number of requests per second. If this is inf, then all the requests are sent at time 0. "
1265
1366
  "Otherwise, we use Poisson process to synthesize the request arrival times. Default is inf.",
1266
1367
  )
1368
+ parser.add_argument(
1369
+ "--max-concurrency",
1370
+ type=int,
1371
+ default=None,
1372
+ help="Maximum number of concurrent requests. This can be used "
1373
+ "to help simulate an environment where a higher level component "
1374
+ "is enforcing a maximum number of concurrent requests. While the "
1375
+ "--request-rate argument controls the rate at which requests are "
1376
+ "initiated, this argument will control how many are actually allowed "
1377
+ "to execute at a time. This means that when used in combination, the "
1378
+ "actual request rate may be lower than specified with --request-rate, "
1379
+ "if the server is not processing requests fast enough to keep up.",
1380
+ )
1267
1381
  parser.add_argument("--seed", type=int, default=1, help="The random seed.")
1268
1382
  parser.add_argument(
1269
1383
  "--multi",
@@ -1331,6 +1445,11 @@ if __name__ == "__main__":
1331
1445
  default=256,
1332
1446
  help="Target length in tokens for outputs in generated-shared-prefix dataset",
1333
1447
  )
1334
-
1448
+ parser.add_argument(
1449
+ "--profile",
1450
+ action="store_true",
1451
+ help="Use Torch Profiler. The endpoint must be launched with "
1452
+ "SGLANG_TORCH_PROFILER_DIR to enable profiler.",
1453
+ )
1335
1454
  args = parser.parse_args()
1336
1455
  run_benchmark(args)
sglang/check_env.py CHANGED
@@ -15,24 +15,21 @@ PACKAGE_LIST = [
15
15
  "flashinfer",
16
16
  "triton",
17
17
  "transformers",
18
- "requests",
19
- "tqdm",
18
+ "torchao",
20
19
  "numpy",
21
20
  "aiohttp",
22
21
  "fastapi",
23
22
  "hf_transfer",
24
23
  "huggingface_hub",
25
24
  "interegular",
26
- "packaging",
27
- "PIL",
28
25
  "psutil",
29
26
  "pydantic",
27
+ "multipart",
28
+ "zmq",
30
29
  "uvicorn",
31
30
  "uvloop",
32
- "zmq",
33
31
  "vllm",
34
32
  "outlines",
35
- "multipart",
36
33
  "openai",
37
34
  "tiktoken",
38
35
  "anthropic",
@@ -78,5 +78,5 @@ class BaseBackend:
78
78
  def flush_cache(self):
79
79
  pass
80
80
 
81
- def get_server_args(self):
81
+ def get_server_info(self):
82
82
  pass
@@ -58,9 +58,9 @@ class RuntimeEndpoint(BaseBackend):
58
58
  )
59
59
  self._assert_success(res)
60
60
 
61
- def get_server_args(self):
61
+ def get_server_info(self):
62
62
  res = http_request(
63
- self.base_url + "/get_server_args",
63
+ self.base_url + "/get_server_info",
64
64
  api_key=self.api_key,
65
65
  verify=self.verify,
66
66
  )
@@ -1,17 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
15
14
 
16
15
  import json
17
16
  import logging
@@ -1,17 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
15
14
 
16
15
  # TODO(lmzheng): make this an optional dependency
17
16
  from sglang.srt.constrained.outlines_backend import build_regex_from_object
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """The baseclass of a backend for grammar-guided constrained decoding."""
17
15
 
18
16
  from concurrent.futures import Future, ThreadPoolExecutor
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """Constrained decoding with outlines backend."""
17
15
 
18
16
  import json
@@ -81,9 +79,22 @@ class OutlinesGrammar(BaseGrammarObject):
81
79
  ):
82
80
  self.state = next_state
83
81
 
84
- def fill_vocab_mask(self, vocab_mask: torch.Tensor):
82
+ def allocate_vocab_mask(
83
+ self, vocab_size: int, batch_size: int, device
84
+ ) -> torch.Tensor:
85
+ return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)
86
+
87
+ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
88
+ tokens = torch.tensor(
89
+ self.guide.get_next_instruction(self.state).tokens, dtype=torch.int64
90
+ ).to(vocab_mask.device, non_blocking=True)
91
+ vocab_mask = vocab_mask[idx]
85
92
  vocab_mask.fill_(1)
86
- vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
93
+ vocab_mask.scatter_(0, tokens, torch.zeros_like(tokens, dtype=torch.bool))
94
+
95
+ @staticmethod
96
+ def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor):
97
+ logits.masked_fill_(vocab_mask, float("-inf"))
87
98
 
88
99
  def copy(self):
89
100
  return OutlinesGrammar(self.guide, self.jump_forward_map)
@@ -1,18 +1,16 @@
1
- """
2
- Copyright 2023-2024 SGLang Team
3
- Licensed under the Apache License, Version 2.0 (the "License");
4
- you may not use this file except in compliance with the License.
5
- You may obtain a copy of the License at
6
-
7
- http://www.apache.org/licenses/LICENSE-2.0
8
-
9
- Unless required by applicable law or agreed to in writing, software
10
- distributed under the License is distributed on an "AS IS" BASIS,
11
- WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
- See the License for the specific language governing permissions and
13
- limitations under the License.
14
- """
15
-
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
16
14
  """
17
15
  Faster constrained decoding with jump forward decoding / compressed finite state machine.
18
16
  Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/