sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85) hide show
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -62,6 +62,7 @@ def gen(
62
62
  name: Optional[str] = None,
63
63
  max_tokens: Optional[int] = None,
64
64
  stop: Optional[Union[str, List[str]]] = None,
65
+ stop_token_ids: Optional[List[int]] = None,
65
66
  temperature: Optional[float] = None,
66
67
  top_p: Optional[float] = None,
67
68
  top_k: Optional[int] = None,
@@ -72,7 +73,7 @@ def gen(
72
73
  logprob_start_len: Optional[int] = None,
73
74
  top_logprobs_num: Optional[int] = None,
74
75
  return_text_in_logprobs: Optional[bool] = None,
75
- dtype: Optional[type] = None,
76
+ dtype: Optional[Union[type, str]] = None,
76
77
  choices: Optional[List[str]] = None,
77
78
  choices_method: Optional[ChoicesSamplingMethod] = None,
78
79
  regex: Optional[str] = None,
@@ -98,6 +99,7 @@ def gen(
98
99
  name,
99
100
  max_tokens,
100
101
  stop,
102
+ stop_token_ids,
101
103
  temperature,
102
104
  top_p,
103
105
  top_k,
@@ -117,6 +119,7 @@ def gen_int(
117
119
  name: Optional[str] = None,
118
120
  max_tokens: Optional[int] = None,
119
121
  stop: Optional[Union[str, List[str]]] = None,
122
+ stop_token_ids: Optional[List[int]] = None,
120
123
  temperature: Optional[float] = None,
121
124
  top_p: Optional[float] = None,
122
125
  top_k: Optional[int] = None,
@@ -132,6 +135,7 @@ def gen_int(
132
135
  name,
133
136
  max_tokens,
134
137
  stop,
138
+ stop_token_ids,
135
139
  temperature,
136
140
  top_p,
137
141
  top_k,
@@ -151,6 +155,7 @@ def gen_string(
151
155
  name: Optional[str] = None,
152
156
  max_tokens: Optional[int] = None,
153
157
  stop: Optional[Union[str, List[str]]] = None,
158
+ stop_token_ids: Optional[List[int]] = None,
154
159
  temperature: Optional[float] = None,
155
160
  top_p: Optional[float] = None,
156
161
  top_k: Optional[int] = None,
@@ -166,6 +171,7 @@ def gen_string(
166
171
  name,
167
172
  max_tokens,
168
173
  stop,
174
+ stop_token_ids,
169
175
  temperature,
170
176
  top_p,
171
177
  top_k,
sglang/bench_latency.py CHANGED
@@ -64,7 +64,7 @@ class BenchArgs:
64
64
  run_name: str = "before"
65
65
  batch_size: Tuple[int] = (1,)
66
66
  input_len: Tuple[int] = (1024,)
67
- output_len: Tuple[int] = (4,)
67
+ output_len: Tuple[int] = (16,)
68
68
  result_filename: str = ""
69
69
  correctness_test: bool = False
70
70
  # This is only used for correctness test
@@ -152,7 +152,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
152
152
  req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
153
153
  req.prefix_indices = []
154
154
  req.sampling_params = sampling_params
155
- req.input_ids = req.origin_input_ids
155
+ req.fill_ids = req.origin_input_ids
156
156
  reqs.append(req)
157
157
 
158
158
  return input_ids, reqs
@@ -163,7 +163,7 @@ def prepare_extend_inputs_for_correctness_test(
163
163
  ):
164
164
  for i in range(len(reqs)):
165
165
  req = reqs[i]
166
- req.input_ids += input_ids[i][bench_args.cut_len :]
166
+ req.fill_ids += input_ids[i][bench_args.cut_len :]
167
167
  req.prefix_indices = model_runner.req_to_token_pool.req_to_token[
168
168
  i, : bench_args.cut_len
169
169
  ]
@@ -182,7 +182,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
182
182
  req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
183
183
  req.prefix_indices = []
184
184
  req.sampling_params = sampling_params
185
- req.input_ids = req.origin_input_ids
185
+ req.fill_ids = req.origin_input_ids
186
186
  reqs.append(req)
187
187
 
188
188
  return reqs
@@ -195,7 +195,7 @@ def extend(reqs, model_runner):
195
195
  token_to_kv_pool=model_runner.token_to_kv_pool,
196
196
  tree_cache=None,
197
197
  )
198
- batch.prepare_for_extend(model_runner.model_config.vocab_size, None)
198
+ batch.prepare_for_extend(model_runner.model_config.vocab_size)
199
199
  output = model_runner.forward(batch, ForwardMode.EXTEND)
200
200
  next_token_ids = batch.sample(output.next_token_logits)
201
201
  return next_token_ids, output.next_token_logits, batch
@@ -221,6 +221,7 @@ def correctness_test(
221
221
 
222
222
  # Prepare inputs
223
223
  input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
224
+ rank_print(f"{input_ids=}")
224
225
 
225
226
  if bench_args.cut_len > 0:
226
227
  # Prefill
@@ -238,7 +239,7 @@ def correctness_test(
238
239
 
239
240
  # Decode
240
241
  output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
241
- for _ in range(bench_args.output_len):
242
+ for _ in range(bench_args.output_len[0]):
242
243
  next_token_ids, _ = decode(next_token_ids, batch, model_runner)
243
244
  for i in range(len(reqs)):
244
245
  output_ids[i].append(next_token_ids[i])
@@ -332,6 +333,7 @@ def latency_test(
332
333
  )
333
334
 
334
335
  # Warm up
336
+ rank_print("Warmup ...")
335
337
  latency_test_run_once(
336
338
  bench_args.run_name,
337
339
  model_runner,
@@ -341,6 +343,7 @@ def latency_test(
341
343
  bench_args.input_len[0],
342
344
  4, # shorter decoding to speed up the warmup
343
345
  )
346
+ rank_print("Benchmark ...")
344
347
 
345
348
  # Run the sweep
346
349
  result_list = []
sglang/bench_serving.py CHANGED
@@ -24,7 +24,7 @@ import warnings
24
24
  from argparse import ArgumentParser
25
25
  from dataclasses import dataclass, field
26
26
  from datetime import datetime
27
- from typing import AsyncGenerator, List, Optional, Tuple, Union
27
+ from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union
28
28
 
29
29
  import aiohttp
30
30
  import numpy as np
@@ -39,6 +39,8 @@ from transformers import (
39
39
 
40
40
  AIOHTTP_TIMEOUT = aiohttp.ClientTimeout(total=6 * 60 * 60)
41
41
 
42
+ global args
43
+
42
44
 
43
45
  @dataclass
44
46
  class RequestFuncInput:
@@ -47,6 +49,7 @@ class RequestFuncInput:
47
49
  prompt_len: int
48
50
  output_len: int
49
51
  model: str
52
+ extra_request_body: Dict[str, Any]
50
53
 
51
54
 
52
55
  @dataclass
@@ -84,6 +87,7 @@ async def async_request_trt_llm(
84
87
  "stream": True,
85
88
  "min_length": request_func_input.output_len,
86
89
  "end_id": 1048576,
90
+ **request_func_input.extra_request_body,
87
91
  }
88
92
  if args.disable_ignore_eos:
89
93
  del payload["min_length"]
@@ -154,6 +158,7 @@ async def async_request_openai_completions(
154
158
  "max_tokens": request_func_input.output_len,
155
159
  "stream": not args.disable_stream,
156
160
  "ignore_eos": not args.disable_ignore_eos,
161
+ **request_func_input.extra_request_body,
157
162
  }
158
163
  headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
159
164
 
@@ -192,7 +197,8 @@ async def async_request_openai_completions(
192
197
  output.ttft = ttft
193
198
 
194
199
  # Decoding phase
195
- output.itl.append(timestamp - most_recent_timestamp)
200
+ else:
201
+ output.itl.append(timestamp - most_recent_timestamp)
196
202
 
197
203
  most_recent_timestamp = timestamp
198
204
  generated_text += data["choices"][0]["text"]
@@ -542,6 +548,7 @@ async def benchmark(
542
548
  request_rate: float,
543
549
  disable_tqdm: bool,
544
550
  enable_multi: bool,
551
+ extra_request_body: Dict[str, Any],
545
552
  ):
546
553
  if backend in ASYNC_REQUEST_FUNCS:
547
554
  request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -556,6 +563,7 @@ async def benchmark(
556
563
  api_url=api_url,
557
564
  prompt_len=test_prompt_len,
558
565
  output_len=test_output_len,
566
+ extra_request_body=extra_request_body,
559
567
  )
560
568
  test_output = await request_func(request_func_input=test_input)
561
569
  if not test_output.success:
@@ -578,6 +586,7 @@ async def benchmark(
578
586
  api_url=api_url,
579
587
  prompt_len=prompt_len,
580
588
  output_len=output_len,
589
+ extra_request_body=extra_request_body,
581
590
  )
582
591
  tasks.append(
583
592
  asyncio.create_task(
@@ -660,19 +669,20 @@ async def benchmark(
660
669
  "backend": args.backend,
661
670
  "dataset_name": args.dataset_name,
662
671
  "request_rate": request_rate,
663
- "total_input": metrics.total_input,
664
- "total_output": metrics.total_output,
665
- "total_output_retokenized": metrics.total_output_retokenized,
666
- "mean_e2e_latency": metrics.mean_e2e_latency_ms,
667
- "median_e2e_latency": metrics.median_e2e_latency_ms,
668
- "median_ttft": metrics.median_ttft_ms,
669
- "median_itl": metrics.median_itl_ms,
670
- "output_token_throughput": metrics.output_throughput,
672
+ "total_input_tokens": metrics.total_input,
673
+ "total_output_tokens": metrics.total_output,
674
+ "total_output_tokens_retokenized": metrics.total_output_retokenized,
675
+ "mean_e2e_latency_ms": metrics.mean_e2e_latency_ms,
676
+ "median_e2e_latency_ms": metrics.median_e2e_latency_ms,
677
+ "median_ttft_ms": metrics.median_ttft_ms,
678
+ "median_itl_ms": metrics.median_itl_ms,
679
+ "output_throughput": metrics.output_throughput,
671
680
  "sharegpt_output_len": args.sharegpt_output_len,
672
681
  "random_input_len": args.random_input_len,
673
682
  "random_output_len": args.random_output_len,
674
683
  "random_range_ratio": args.random_range_ratio,
675
- "benchmark_duration": benchmark_duration,
684
+ "duration": benchmark_duration,
685
+ "completed": metrics.completed,
676
686
  }
677
687
  else:
678
688
  print(f"Error running benchmark for request rate: {request_rate}")
@@ -742,10 +752,18 @@ def check_chat_template(model_path):
742
752
  return False
743
753
 
744
754
 
745
- def fire(args: argparse.Namespace):
755
+ def run_benchmark(args_: argparse.Namespace):
756
+ global args
757
+ args = args_
758
+
759
+ set_ulimit()
746
760
  random.seed(args.seed)
747
761
  np.random.seed(args.seed)
748
762
 
763
+ extra_request_body = {}
764
+ if args.extra_request_body:
765
+ extra_request_body = json.loads(args.extra_request_body)
766
+
749
767
  if args.port is None:
750
768
  args.port = {
751
769
  "sglang": 30000,
@@ -838,10 +856,11 @@ def fire(args: argparse.Namespace):
838
856
  request_rate=rate,
839
857
  disable_tqdm=args.disable_tqdm,
840
858
  enable_multi=args.multi,
859
+ extra_request_body=extra_request_body,
841
860
  )
842
861
  )
843
862
  else:
844
- asyncio.run(
863
+ return asyncio.run(
845
864
  benchmark(
846
865
  backend=backend,
847
866
  api_url=api_url,
@@ -851,6 +870,7 @@ def fire(args: argparse.Namespace):
851
870
  request_rate=args.request_rate,
852
871
  disable_tqdm=args.disable_tqdm,
853
872
  enable_multi=args.multi,
873
+ extra_request_body=extra_request_body,
854
874
  )
855
875
  )
856
876
 
@@ -949,11 +969,6 @@ if __name__ == "__main__":
949
969
  "Otherwise, we use Poisson process to synthesize the request arrival times. Default is 128.0.",
950
970
  )
951
971
  parser.add_argument("--seed", type=int, default=0, help="Default is 0.")
952
- parser.add_argument(
953
- "--disable-tqdm",
954
- action="store_true",
955
- help="Specify to disable tqdm progress bar.",
956
- )
957
972
  parser.add_argument(
958
973
  "--multi",
959
974
  action="store_true",
@@ -966,6 +981,11 @@ if __name__ == "__main__":
966
981
  help="Range of request rates in the format start,stop,step. Default is 2,34,2. It also supports a list of request rates, requiring the parameters to not equal three.",
967
982
  )
968
983
  parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
984
+ parser.add_argument(
985
+ "--disable-tqdm",
986
+ action="store_true",
987
+ help="Specify to disable tqdm progress bar.",
988
+ )
969
989
  parser.add_argument(
970
990
  "--disable-stream",
971
991
  action="store_true",
@@ -976,8 +996,12 @@ if __name__ == "__main__":
976
996
  action="store_true",
977
997
  help="Disable ignoring EOS.",
978
998
  )
979
-
980
- set_ulimit()
981
-
999
+ parser.add_argument(
1000
+ "--extra-request-body",
1001
+ metavar='{"key1": "value1", "key2": "value2"}',
1002
+ type=str,
1003
+ help="Append given JSON object to the request payload. You can use this to specify"
1004
+ "additional generate params like sampling params.",
1005
+ )
982
1006
  args = parser.parse_args()
983
- fire(args)
1007
+ run_benchmark(args)
sglang/global_config.py CHANGED
@@ -27,7 +27,7 @@ class GlobalConfig:
27
27
  # Runtime constants: others
28
28
  self.num_continue_decode_steps = 10
29
29
  self.retract_decode_steps = 20
30
- self.flashinfer_workspace_size = 192 * 1024 * 1024
30
+ self.flashinfer_workspace_size = 384 * 1024 * 1024
31
31
 
32
32
  # Output tokenization configs
33
33
  self.skip_special_tokens_in_output = True
@@ -1,21 +1,23 @@
1
1
  import json
2
+ import warnings
2
3
  from typing import List, Optional
3
4
 
4
5
  from sglang.global_config import global_config
5
6
  from sglang.lang.backend.base_backend import BaseBackend
6
7
  from sglang.lang.chat_template import get_chat_template_by_model_path
7
- from sglang.lang.choices import (
8
- ChoicesDecision,
9
- ChoicesSamplingMethod,
10
- token_length_normalized,
11
- )
8
+ from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
12
9
  from sglang.lang.interpreter import StreamExecutor
13
- from sglang.lang.ir import SglSamplingParams
10
+ from sglang.lang.ir import (
11
+ REGEX_BOOL,
12
+ REGEX_FLOAT,
13
+ REGEX_INT,
14
+ REGEX_STR,
15
+ SglSamplingParams,
16
+ )
14
17
  from sglang.utils import http_request
15
18
 
16
19
 
17
20
  class RuntimeEndpoint(BaseBackend):
18
-
19
21
  def __init__(
20
22
  self,
21
23
  base_url: str,
@@ -95,32 +97,52 @@ class RuntimeEndpoint(BaseBackend):
95
97
  )
96
98
  self._assert_success(res)
97
99
 
100
+ def _handle_dtype_to_regex(self, sampling_params: SglSamplingParams):
101
+ if sampling_params.dtype is None:
102
+ return
103
+
104
+ if sampling_params.stop == ():
105
+ sampling_params.stop = []
106
+
107
+ dtype_regex = None
108
+ if sampling_params.dtype in ["int", int]:
109
+
110
+ dtype_regex = REGEX_INT
111
+ sampling_params.stop.extend([" ", "\n"])
112
+ elif sampling_params.dtype in ["float", float]:
113
+
114
+ dtype_regex = REGEX_FLOAT
115
+ sampling_params.stop.extend([" ", "\n"])
116
+ elif sampling_params.dtype in ["str", str]:
117
+
118
+ dtype_regex = REGEX_STR
119
+ elif sampling_params.dtype in ["bool", bool]:
120
+
121
+ dtype_regex = REGEX_BOOL
122
+ else:
123
+ raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
124
+
125
+ if dtype_regex is not None and sampling_params.regex is not None:
126
+ warnings.warn(
127
+ f"Both dtype and regex are set. Only dtype will be used. dtype: {sampling_params.dtype}, regex: {sampling_params.regex}"
128
+ )
129
+
130
+ sampling_params.regex = dtype_regex
131
+
98
132
  def generate(
99
133
  self,
100
134
  s: StreamExecutor,
101
135
  sampling_params: SglSamplingParams,
102
136
  ):
103
- if sampling_params.dtype is None:
104
- data = {
105
- "text": s.text_,
106
- "sampling_params": {
107
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
108
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
109
- **sampling_params.to_srt_kwargs(),
110
- },
111
- }
112
- elif sampling_params.dtype in [int, "int"]:
113
- data = {
114
- "text": s.text_,
115
- "sampling_params": {
116
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
117
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
118
- "dtype": "int",
119
- **sampling_params.to_srt_kwargs(),
120
- },
121
- }
122
- else:
123
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
137
+ self._handle_dtype_to_regex(sampling_params)
138
+ data = {
139
+ "text": s.text_,
140
+ "sampling_params": {
141
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
142
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
143
+ **sampling_params.to_srt_kwargs(),
144
+ },
145
+ }
124
146
 
125
147
  for item in [
126
148
  "return_logprob",
@@ -151,27 +173,16 @@ class RuntimeEndpoint(BaseBackend):
151
173
  s: StreamExecutor,
152
174
  sampling_params: SglSamplingParams,
153
175
  ):
154
- if sampling_params.dtype is None:
155
- data = {
156
- "text": s.text_,
157
- "sampling_params": {
158
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
159
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
160
- **sampling_params.to_srt_kwargs(),
161
- },
162
- }
163
- elif sampling_params.dtype in [int, "int"]:
164
- data = {
165
- "text": s.text_,
166
- "sampling_params": {
167
- "skip_special_tokens": global_config.skip_special_tokens_in_output,
168
- "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
169
- "dtype": "int",
170
- **sampling_params.to_srt_kwargs(),
171
- },
172
- }
173
- else:
174
- raise RuntimeError(f"Invalid dtype: {sampling_params.dtype}")
176
+ self._handle_dtype_to_regex(sampling_params)
177
+
178
+ data = {
179
+ "text": s.text_,
180
+ "sampling_params": {
181
+ "skip_special_tokens": global_config.skip_special_tokens_in_output,
182
+ "spaces_between_special_tokens": global_config.spaces_between_special_tokens_in_out,
183
+ **sampling_params.to_srt_kwargs(),
184
+ },
185
+ }
175
186
 
176
187
  for item in [
177
188
  "return_logprob",
sglang/lang/compiler.py CHANGED
@@ -125,7 +125,7 @@ class CompiledFunction:
125
125
  def run(
126
126
  self,
127
127
  *,
128
- max_new_tokens: int = 16,
128
+ max_new_tokens: int = 128,
129
129
  stop: Union[str, List[str]] = (),
130
130
  temperature: float = 1.0,
131
131
  top_p: float = 1.0,
@@ -155,7 +155,7 @@ class CompiledFunction:
155
155
  self,
156
156
  batch_kwargs,
157
157
  *,
158
- max_new_tokens: int = 16,
158
+ max_new_tokens: int = 128,
159
159
  stop: Union[str, List[str]] = (),
160
160
  temperature: float = 1.0,
161
161
  top_p: float = 1.0,
@@ -20,7 +20,6 @@ from sglang.lang.ir import (
20
20
  SglConstantText,
21
21
  SglExpr,
22
22
  SglExprList,
23
- SglFunction,
24
23
  SglGen,
25
24
  SglImage,
26
25
  SglRoleBegin,
@@ -181,8 +180,10 @@ class StreamExecutor:
181
180
  num_api_spec_tokens=None,
182
181
  use_thread=True,
183
182
  ):
183
+ from sglang.lang.backend.base_backend import BaseBackend
184
+
184
185
  self.sid = uuid.uuid4().hex
185
- self.backend = backend
186
+ self.backend: BaseBackend = backend
186
187
  self.arguments: Dict[str, Any] = arguments
187
188
  self.default_sampling_para = default_sampling_para
188
189
  self.stream = stream
@@ -658,6 +659,7 @@ class StreamExecutor:
658
659
  for item in [
659
660
  "max_new_tokens",
660
661
  "stop",
662
+ "stop_token_ids",
661
663
  "temperature",
662
664
  "top_p",
663
665
  "top_k",
sglang/lang/ir.py CHANGED
@@ -8,16 +8,17 @@ from typing import List, Optional, Union
8
8
  from sglang.global_config import global_config
9
9
  from sglang.lang.choices import ChoicesSamplingMethod
10
10
 
11
- REGEX_INT = r"[-+]?[0-9]+"
12
- REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+"
11
+ REGEX_INT = r"[-+]?[0-9]+[ \n]*"
12
+ REGEX_FLOAT = r"[-+]?[0-9]*\.?[0-9]+[ \n]*"
13
13
  REGEX_BOOL = r"(True|False)"
14
- REGEX_STRING = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
14
+ REGEX_STR = r"\"[\w\d\s]*\"" # bugs with regex r"\".*\"" in interegular pkg
15
15
 
16
16
 
17
17
  @dataclasses.dataclass
18
18
  class SglSamplingParams:
19
- max_new_tokens: int = 16
19
+ max_new_tokens: int = 128
20
20
  stop: Union[str, List[str]] = ()
21
+ stop_token_ids: Optional[List[int]] = ()
21
22
  temperature: float = 1.0
22
23
  top_p: float = 1.0
23
24
  top_k: int = -1 # -1 means disable
@@ -37,6 +38,7 @@ class SglSamplingParams:
37
38
  return SglSamplingParams(
38
39
  self.max_new_tokens,
39
40
  self.stop,
41
+ self.stop_token_ids,
40
42
  self.temperature,
41
43
  self.top_p,
42
44
  self.top_k,
@@ -108,6 +110,7 @@ class SglSamplingParams:
108
110
  return {
109
111
  "max_new_tokens": self.max_new_tokens,
110
112
  "stop": self.stop,
113
+ "stop_token_ids": self.stop_token_ids,
111
114
  "temperature": self.temperature,
112
115
  "top_p": self.top_p,
113
116
  "top_k": self.top_k,
@@ -140,8 +143,9 @@ class SglFunction:
140
143
  def run(
141
144
  self,
142
145
  *args,
143
- max_new_tokens: int = 16,
144
- stop: Union[str, List[str]] = (),
146
+ max_new_tokens: int = 128,
147
+ stop: Union[str, List[str]] = [],
148
+ stop_token_ids: Optional[List[int]] = [],
145
149
  temperature: float = 1.0,
146
150
  top_p: float = 1.0,
147
151
  top_k: int = -1,
@@ -161,6 +165,7 @@ class SglFunction:
161
165
  default_sampling_para = SglSamplingParams(
162
166
  max_new_tokens=max_new_tokens,
163
167
  stop=stop,
168
+ stop_token_ids=stop_token_ids,
164
169
  temperature=temperature,
165
170
  top_p=top_p,
166
171
  top_k=top_k,
@@ -179,8 +184,9 @@ class SglFunction:
179
184
  self,
180
185
  batch_kwargs,
181
186
  *,
182
- max_new_tokens: int = 16,
187
+ max_new_tokens: int = 128,
183
188
  stop: Union[str, List[str]] = (),
189
+ stop_token_ids: Optional[List[int]] = [],
184
190
  temperature: float = 1.0,
185
191
  top_p: float = 1.0,
186
192
  top_k: int = -1,
@@ -218,6 +224,7 @@ class SglFunction:
218
224
  default_sampling_para = SglSamplingParams(
219
225
  max_new_tokens=max_new_tokens,
220
226
  stop=stop,
227
+ stop_token_ids=stop_token_ids,
221
228
  temperature=temperature,
222
229
  top_p=top_p,
223
230
  top_k=top_k,
@@ -397,6 +404,7 @@ class SglGen(SglExpr):
397
404
  name: Optional[str] = None,
398
405
  max_new_tokens: Optional[int] = None,
399
406
  stop: Optional[Union[str, List[str]]] = None,
407
+ stop_token_ids: Optional[List[int]] = None,
400
408
  temperature: Optional[float] = None,
401
409
  top_p: Optional[float] = None,
402
410
  top_k: Optional[int] = None,
@@ -416,6 +424,7 @@ class SglGen(SglExpr):
416
424
  self.sampling_params = SglSamplingParams(
417
425
  max_new_tokens=max_new_tokens,
418
426
  stop=stop,
427
+ stop_token_ids=stop_token_ids,
419
428
  temperature=temperature,
420
429
  top_p=top_p,
421
430
  top_k=top_k,
@@ -54,7 +54,7 @@ class BaseToolCache:
54
54
  return val
55
55
 
56
56
  def init_value(self, key):
57
- raise NotImplementedError
57
+ raise NotImplementedError()
58
58
 
59
59
  def get_cache_hit_rate(self):
60
60
  if self.metrics["total"] == 0:
@@ -20,10 +20,20 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache
20
20
 
21
21
 
22
22
  class FSMCache(BaseToolCache):
23
- def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
23
+ def __init__(
24
+ self,
25
+ tokenizer_path,
26
+ tokenizer_args_dict,
27
+ enable=True,
28
+ skip_tokenizer_init=False,
29
+ ):
24
30
  super().__init__(enable=enable)
25
31
 
26
- if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
32
+ if (
33
+ skip_tokenizer_init
34
+ or tokenizer_path.endswith(".json")
35
+ or tokenizer_path.endswith(".model")
36
+ ):
27
37
  # Do not support TiktokenTokenizer or SentencePieceTokenizer
28
38
  return
29
39
 
@@ -62,16 +62,22 @@ class JumpForwardMap:
62
62
  id_to_symbol.setdefault(id_, []).append(symbol)
63
63
 
64
64
  transitions = fsm_info.transitions
65
+
65
66
  outgoings_ct = defaultdict(int)
66
- state_to_jump_forward = {}
67
+ # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
68
+ for s in fsm_info.finals:
69
+ outgoings_ct[s] = 1
67
70
 
71
+ state_to_jump_forward = {}
68
72
  for (state, id_), next_state in transitions.items():
69
73
  if id_ == fsm_info.alphabet_anything_value:
74
+ # Arbitrarily symbol cannot be recognized as jump forward
70
75
  continue
76
+
71
77
  symbols = id_to_symbol[id_]
72
78
  for c in symbols:
73
79
  if len(c) > 1:
74
- # Skip byte level transitions
80
+ # Skip byte level transitions like c = "5E"
75
81
  continue
76
82
 
77
83
  outgoings_ct[state] += 1
@@ -87,6 +93,9 @@ class JumpForwardMap:
87
93
 
88
94
  # Process the byte level jump forward
89
95
  outgoings_ct = defaultdict(int)
96
+ for s in fsm_info.finals:
97
+ outgoings_ct[s] = 1
98
+
90
99
  for (state, id_), next_state in transitions.items():
91
100
  if id_ == fsm_info.alphabet_anything_value:
92
101
  continue
@@ -177,3 +186,5 @@ if __name__ == "__main__":
177
186
  test_main(r"霍格沃茨特快列车|霍比特人比尔博")
178
187
  # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
179
188
  # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
189
+
190
+ test_main(r"[-+]?[0-9]+[ ]*")