sglang 0.3.1.post3__py3-none-any.whl → 0.3.3__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
Files changed (92)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +23 -1
  3. sglang/bench_latency.py +48 -33
  4. sglang/bench_server_latency.py +0 -6
  5. sglang/bench_serving.py +2 -2
  6. sglang/lang/backend/runtime_endpoint.py +14 -1
  7. sglang/lang/interpreter.py +16 -6
  8. sglang/lang/ir.py +20 -4
  9. sglang/srt/configs/model_config.py +11 -9
  10. sglang/srt/constrained/fsm_cache.py +9 -1
  11. sglang/srt/constrained/jump_forward.py +15 -2
  12. sglang/srt/hf_transformers_utils.py +1 -0
  13. sglang/srt/layers/activation.py +4 -4
  14. sglang/srt/layers/attention/__init__.py +49 -0
  15. sglang/srt/layers/attention/flashinfer_backend.py +277 -0
  16. sglang/srt/layers/{flashinfer_utils.py → attention/flashinfer_utils.py} +82 -80
  17. sglang/srt/layers/attention/triton_backend.py +161 -0
  18. sglang/srt/layers/{triton_attention → attention/triton_ops}/extend_attention.py +3 -1
  19. sglang/srt/layers/fused_moe/patch.py +117 -0
  20. sglang/srt/layers/layernorm.py +4 -4
  21. sglang/srt/layers/logits_processor.py +19 -15
  22. sglang/srt/layers/pooler.py +3 -3
  23. sglang/srt/layers/quantization/__init__.py +0 -2
  24. sglang/srt/layers/radix_attention.py +6 -4
  25. sglang/srt/layers/sampler.py +6 -4
  26. sglang/srt/layers/torchao_utils.py +18 -0
  27. sglang/srt/lora/lora.py +20 -21
  28. sglang/srt/lora/lora_manager.py +97 -25
  29. sglang/srt/managers/detokenizer_manager.py +31 -18
  30. sglang/srt/managers/image_processor.py +187 -0
  31. sglang/srt/managers/io_struct.py +99 -75
  32. sglang/srt/managers/schedule_batch.py +187 -68
  33. sglang/srt/managers/{policy_scheduler.py → schedule_policy.py} +31 -21
  34. sglang/srt/managers/scheduler.py +1021 -0
  35. sglang/srt/managers/tokenizer_manager.py +120 -247
  36. sglang/srt/managers/tp_worker.py +28 -925
  37. sglang/srt/mem_cache/memory_pool.py +34 -52
  38. sglang/srt/mem_cache/radix_cache.py +5 -5
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -25
  40. sglang/srt/model_executor/forward_batch_info.py +94 -97
  41. sglang/srt/model_executor/model_runner.py +76 -78
  42. sglang/srt/models/baichuan.py +10 -10
  43. sglang/srt/models/chatglm.py +12 -12
  44. sglang/srt/models/commandr.py +10 -10
  45. sglang/srt/models/dbrx.py +12 -12
  46. sglang/srt/models/deepseek.py +10 -10
  47. sglang/srt/models/deepseek_v2.py +14 -15
  48. sglang/srt/models/exaone.py +10 -10
  49. sglang/srt/models/gemma.py +10 -10
  50. sglang/srt/models/gemma2.py +11 -11
  51. sglang/srt/models/gpt_bigcode.py +10 -10
  52. sglang/srt/models/grok.py +10 -10
  53. sglang/srt/models/internlm2.py +10 -10
  54. sglang/srt/models/llama.py +22 -10
  55. sglang/srt/models/llama_classification.py +5 -5
  56. sglang/srt/models/llama_embedding.py +4 -4
  57. sglang/srt/models/llama_reward.py +142 -0
  58. sglang/srt/models/llava.py +39 -33
  59. sglang/srt/models/llavavid.py +31 -28
  60. sglang/srt/models/minicpm.py +10 -10
  61. sglang/srt/models/minicpm3.py +14 -15
  62. sglang/srt/models/mixtral.py +10 -10
  63. sglang/srt/models/mixtral_quant.py +10 -10
  64. sglang/srt/models/olmoe.py +10 -10
  65. sglang/srt/models/qwen.py +10 -10
  66. sglang/srt/models/qwen2.py +11 -11
  67. sglang/srt/models/qwen2_moe.py +10 -10
  68. sglang/srt/models/stablelm.py +10 -10
  69. sglang/srt/models/torch_native_llama.py +506 -0
  70. sglang/srt/models/xverse.py +10 -10
  71. sglang/srt/models/xverse_moe.py +10 -10
  72. sglang/srt/openai_api/adapter.py +7 -0
  73. sglang/srt/sampling/sampling_batch_info.py +36 -27
  74. sglang/srt/sampling/sampling_params.py +3 -1
  75. sglang/srt/server.py +170 -119
  76. sglang/srt/server_args.py +54 -27
  77. sglang/srt/utils.py +101 -128
  78. sglang/test/runners.py +76 -33
  79. sglang/test/test_programs.py +38 -5
  80. sglang/test/test_utils.py +53 -9
  81. sglang/version.py +1 -1
  82. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/METADATA +42 -23
  83. sglang-0.3.3.dist-info/RECORD +139 -0
  84. sglang/srt/layers/attention_backend.py +0 -482
  85. sglang/srt/managers/controller_multi.py +0 -207
  86. sglang/srt/managers/controller_single.py +0 -164
  87. sglang-0.3.1.post3.dist-info/RECORD +0 -134
  88. /sglang/srt/layers/{triton_attention → attention/triton_ops}/decode_attention.py +0 -0
  89. /sglang/srt/layers/{triton_attention → attention/triton_ops}/prefill_attention.py +0 -0
  90. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/LICENSE +0 -0
  91. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/WHEEL +0 -0
  92. {sglang-0.3.1.post3.dist-info → sglang-0.3.3.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,6 +1,7 @@
 # SGL API Components

 from sglang.api import (
+    Engine,
     Runtime,
     assistant,
     assistant_begin,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
 # SGLang DSL APIs
 __all__ = [
     "Runtime",
+    "Engine",
     "assistant",
     "assistant_begin",
     "assistant_end",
sglang/api.py CHANGED
@@ -33,13 +33,23 @@ def function(


 def Runtime(*args, **kwargs):
-    # Avoid importing unnecessary dependency
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
     from sglang.srt.server import Runtime

     return Runtime(*args, **kwargs)


+def Engine(*args, **kwargs):
+    os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
+
+    # Avoid importing unnecessary dependency
+    from sglang.srt.server import Engine
+
+    return Engine(*args, **kwargs)
+
+
 def set_default_backend(backend: BaseBackend):
     global_config.default_backend = backend

@@ -48,6 +58,10 @@ def flush_cache(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return False
+
+    # If backend is Runtime
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
     return backend.flush_cache()


@@ -55,12 +69,17 @@ def get_server_args(backend: Optional[BaseBackend] = None):
     backend = backend or global_config.default_backend
     if backend is None:
         return None
+
+    # If backend is Runtime
+    if hasattr(backend, "endpoint"):
+        backend = backend.endpoint
     return backend.get_server_args()


 def gen(
     name: Optional[str] = None,
     max_tokens: Optional[int] = None,
+    min_tokens: Optional[int] = None,
     stop: Optional[Union[str, List[str]]] = None,
     stop_token_ids: Optional[List[int]] = None,
     temperature: Optional[float] = None,
@@ -100,6 +119,7 @@ def gen(
     return SglGen(
         name,
         max_tokens,
+        min_tokens,
         stop,
         stop_token_ids,
         temperature,
@@ -139,6 +159,7 @@ def gen_int(
     return SglGen(
         name,
         max_tokens,
+        None,
         stop,
         stop_token_ids,
         temperature,
@@ -177,6 +198,7 @@ def gen_string(
     return SglGen(
         name,
         max_tokens,
+        None,
         stop,
         stop_token_ids,
         temperature,
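A minimal usage sketch of the two API additions above: the top-level Engine export and the new min_tokens argument of gen(). The Engine constructor keyword (model_path) and the model name are illustrative assumptions, and the wiring between the engine and the DSL program is omitted:

    import sglang as sgl

    # Engine is now re-exported at the package root; it forwards to sglang.srt.server.Engine.
    # The model_path keyword and model name below are hypothetical examples.
    engine = sgl.Engine(model_path="meta-llama/Meta-Llama-3-8B-Instruct")

    @sgl.function
    def short_answer(s, question):
        s += question
        # gen() now threads min_tokens through to SglGen and SglSamplingParams.
        s += sgl.gen("answer", max_tokens=64, min_tokens=4)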
sglang/bench_latency.py CHANGED
@@ -47,6 +47,7 @@ I'm going to the park
 import argparse
 import dataclasses
 import itertools
+import json
 import logging
 import multiprocessing
 import os
@@ -62,10 +63,11 @@ import torch.distributed as dist
 from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server import _set_envs_and_config
-from sglang.srt.server_args import ServerArgs
+from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
     configure_logger,
     kill_child_process,
@@ -121,7 +123,7 @@ class BenchArgs:
     )


-def load_model(server_args, tp_rank):
+def load_model(server_args, port_args, tp_rank):
     suppress_other_loggers()
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

@@ -129,6 +131,7 @@ def load_model(server_args, tp_rank):
         server_args.model_path,
         server_args.trust_remote_code,
         context_length=server_args.context_length,
+        model_override_args=json.loads(server_args.json_model_override_args),
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -136,7 +139,7 @@ def load_model(server_args, tp_rank):
         gpu_id=tp_rank,
         tp_rank=tp_rank,
         tp_size=server_args.tp_size,
-        nccl_port=28888,
+        nccl_port=port_args.nccl_ports[0],
         server_args=server_args,
     )
     rank_print(f"max_total_num_tokens={model_runner.max_total_num_tokens}")
@@ -167,9 +170,13 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         assert len(input_ids[i]) > bench_args.cut_len

         tmp_input_ids = input_ids[i][: bench_args.cut_len]
-        req = Req(rid=i, origin_input_text=prompts[i], origin_input_ids=tmp_input_ids)
+        req = Req(
+            rid=i,
+            origin_input_text=prompts[i],
+            origin_input_ids=tmp_input_ids,
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -199,9 +206,13 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):

     reqs = []
     for i in range(len(input_ids)):
-        req = Req(rid=i, origin_input_text="", origin_input_ids=list(input_ids[i]))
+        req = Req(
+            rid=i,
+            origin_input_text="",
+            origin_input_ids=list(input_ids[i]),
+            sampling_params=sampling_params,
+        )
         req.prefix_indices = []
-        req.sampling_params = sampling_params
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
         reqs.append(req)
@@ -217,28 +228,33 @@ def extend(reqs, model_runner):
         tree_cache=None,
     )
     batch.prepare_for_extend(model_runner.model_config.vocab_size)
-    logits_output = model_runner.forward(batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits, batch


 def decode(input_token_ids, batch, model_runner):
     batch.prepare_for_decode(input_token_ids)
-    logits_output = model_runner.forward(batch)
-    next_token_ids = model_runner.sample(logits_output, batch).tolist()
+    model_worker_batch = batch.get_model_worker_batch()
+    forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
+    logits_output = model_runner.forward(forward_batch)
+    next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
     return next_token_ids, logits_output.next_token_logits


 @torch.inference_mode()
 def correctness_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

     # Prepare inputs
     input_ids, reqs = prepare_inputs_for_correctness_test(bench_args, tokenizer)
@@ -260,7 +276,7 @@ def correctness_test(

     # Decode
     output_ids = [input_ids[i] + [next_token_ids[i]] for i in range(len(input_ids))]
-    for _ in range(bench_args.output_len[0]):
+    for _ in range(bench_args.output_len[0] - 1):
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
         for i in range(len(reqs)):
             output_ids[i].append(next_token_ids[i])
@@ -311,7 +327,7 @@ def latency_test_run_once(

     # Decode
     decode_latencies = []
-    for i in range(output_len):
+    for i in range(output_len - 1):
         torch.cuda.synchronize()
         tic = time.time()
         next_token_ids, _ = decode(next_token_ids, batch, model_runner)
@@ -324,13 +340,16 @@
             rank_print(
                 f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
-    med_decode_latency = np.median(decode_latencies)
-    med_decode_throughput = batch_size / med_decode_latency
-    rank_print(
-        f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
-    )
-    measurement_results["median_decode_latency"] = med_decode_latency
-    measurement_results["median_decode_throughput"] = med_decode_throughput
+
+    # record decode timing from 2nd output
+    if output_len > 1:
+        med_decode_latency = np.median(decode_latencies)
+        med_decode_throughput = batch_size / med_decode_latency
+        rank_print(
+            f"Decode. median latency: {med_decode_latency:6.5f} s, median throughput: {med_decode_throughput:9.2f} token/s"
+        )
+        measurement_results["median_decode_latency"] = med_decode_latency
+        measurement_results["median_decode_throughput"] = med_decode_throughput

     throughput = (input_len + output_len) * batch_size / tot_latency
     rank_print(
@@ -343,15 +362,15 @@

 def latency_test(
     server_args,
+    port_args,
     bench_args,
     tp_rank,
 ):
     configure_logger(server_args, prefix=f" TP{tp_rank}")
-    _set_envs_and_config(server_args)
     rank_print = print if tp_rank == 0 else lambda *args, **kwargs: None

     # Load the model
-    model_runner, tokenizer = load_model(server_args, tp_rank)
+    model_runner, tokenizer = load_model(server_args, port_args, tp_rank)

     # Prepare inputs for warm up
     reqs = prepare_synthetic_inputs_for_latency_test(
@@ -367,7 +386,7 @@ def latency_test(
         reqs,
         bench_args.batch_size[0],
         bench_args.input_len[0],
-        4,  # shorter decoding to speed up the warmup
+        8,  # shorter decoding to speed up the warmup
     )
     rank_print("Benchmark ...")

@@ -453,6 +472,7 @@


 def main(server_args, bench_args):
+    _set_envs_and_config(server_args)

     if server_args.model_path:
         if bench_args.correctness_test:
@@ -468,8 +488,10 @@ def main(server_args, bench_args):
            "provide --result-filename for plotting the results"
        )

+    port_args = PortArgs.init_new(server_args)
+
     if server_args.tp_size == 1:
-        work_func(server_args, bench_args, 0)
+        work_func(server_args, port_args, bench_args, 0)
     else:
         workers = []
         for tp_rank in range(server_args.tp_size):
@@ -477,6 +499,7 @@
                 target=work_func,
                 args=(
                     server_args,
+                    port_args,
                     bench_args,
                     tp_rank,
                 ),
@@ -491,18 +514,10 @@


 if __name__ == "__main__":
-    multiprocessing.set_start_method("spawn", force=True)
-
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)

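For readers skimming the hunks above: the benchmark's extend()/decode() helpers now go through an intermediate ModelWorkerBatch and ForwardBatch instead of passing the ScheduleBatch straight to the runner. A condensed sketch of that flow, assuming `batch` and `model_runner` are set up as in this script:

    from sglang.srt.model_executor.forward_batch_info import ForwardBatch

    def run_one_step(batch, model_runner):
        # ScheduleBatch -> ModelWorkerBatch -> ForwardBatch, then forward + sample
        model_worker_batch = batch.get_model_worker_batch()
        forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
        logits_output = model_runner.forward(forward_batch)
        next_token_ids = model_runner.sample(logits_output, forward_batch).tolist()
        return next_token_ids, logits_output.next_token_logits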
sglang/bench_server_latency.py CHANGED
@@ -174,13 +174,7 @@ if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
-    # For this script, model-path is not required
-    assert (
-        parser._actions[1].option_strings[0] == "--model-path"
-    ), "options changed, this code need to be updated"
-    parser._actions[1].required = False
     args = parser.parse_args()
-
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)

sglang/bench_serving.py CHANGED
@@ -845,6 +845,7 @@ def run_benchmark(args_: argparse.Namespace):
     tokenizer = get_tokenizer(tokenizer_id)

     if args.dataset_name == "sharegpt":
+        assert args.random_input_len is None and args.random_output_len is None
         input_requests = sample_sharegpt_requests(
             dataset_path=args.dataset_path,
             num_requests=args.num_prompts,
@@ -852,6 +853,7 @@
             fixed_output_len=args.sharegpt_output_len,
         )
     elif args.dataset_name == "random":
+        assert args.random_input_len is not None and args.random_output_len is not None
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -964,13 +966,11 @@ if __name__ == "__main__":
     parser.add_argument(
         "--random-input-len",
         type=int,
-        default=1024,
         help="Number of input tokens per request, used only for random dataset.",
     )
     parser.add_argument(
         "--random-output-len",
         type=int,
-        default=128,
         help="Number of output tokens per request, used only for random dataset.",
     )
     parser.add_argument(
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -235,6 +235,7 @@ class RuntimeEndpoint(BaseBackend):
         data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
         obj = self._generate_http_request(s, data)
         prompt_len = obj["meta_info"]["prompt_tokens"]
+        logprob_start_len = max(prompt_len - 2, 0)  # For token healing

         # Compute logprob
         data = {
@@ -244,7 +245,8 @@
                 "temperature": 0,
             },
             "return_logprob": True,
-            "logprob_start_len": max(prompt_len - 2, 0),  # for token healing
+            "return_text_in_logprobs": True,
+            "logprob_start_len": logprob_start_len,
         }
         obj = self._generate_http_request(s, data)

@@ -254,6 +256,17 @@
         input_token_logprobs = [r["meta_info"]["input_token_logprobs"] for r in obj]
         output_token_logprobs = [r["meta_info"]["output_token_logprobs"] for r in obj]

+        # Remove extra token if no token healing occurred
+        for i in range(len(input_token_logprobs)):
+            healed_token_str = input_token_logprobs[i][0][-1]
+            if s.text_.endswith(healed_token_str):
+                healed_token_logprob = input_token_logprobs[i][0][0]
+                normalized_prompt_logprobs[i] = (
+                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
+                    - healed_token_logprob
+                ) / (len(input_token_logprobs[i]) - 1)
+                input_token_logprobs[i] = input_token_logprobs[i][1:]
+
         # Compute unconditional logprobs if required
         if choices_method.requires_unconditional_logprobs:
             input_ids = [[el[1] for el in subl] for subl in input_token_logprobs]
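The loop added above strips the extra token requested for token healing when no healing actually happened, and rescales the normalized prompt logprob so it remains a mean over the remaining tokens. A small numeric check of that rescaling (the values are made up):

    n = 4                         # tokens covered by the normalized prompt logprob
    norm = -1.0                   # current mean logprob over those n tokens
    healed_token_logprob = -2.5   # logprob of the token being dropped
    new_norm = (norm * n - healed_token_logprob) / (n - 1)
    assert abs(new_norm - (-0.5)) < 1e-9  # the remaining 3 tokens average -0.5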
sglang/lang/interpreter.py CHANGED
@@ -2,6 +2,7 @@

 import asyncio
 import contextvars
+import copy
 import multiprocessing
 import queue
 import threading
@@ -652,9 +653,22 @@ class StreamExecutor:
             self._init_var_event(e)

     def _resolve_sampling_params(self, sampling_params):
-        clone = None
+        """
+        Construct sampling param based on default + override values
+
+        The default values of sampling are populated in `default_sampling_para` via sgl.function.run(...sampling_args)
+        , and `sampling_params` contains the override values from sgl.gen().
+
+        Here we use default_sampling_para as the base and override the values if they exist in `sampling_params`.
+        It also extends the stop tokens based on the chat template.
+        """
+
+        # deepcopy is required because the dict has lists inside
+        clone = copy.deepcopy(self.default_sampling_para)
+
         for item in [
             "max_new_tokens",
+            "min_new_tokens",
             "stop",
             "stop_token_ids",
             "temperature",
@@ -674,20 +688,16 @@
         ]:
             value = getattr(sampling_params, item, None)
             if value is not None:
-                if clone is None:
-                    clone = self.default_sampling_para.clone()
                 setattr(clone, item, value)

         if self.chat_template.stop_str:
-            if not clone:
-                clone = self.default_sampling_para.clone()
             if clone.stop == ():
                 clone.stop = []
             elif isinstance(clone.stop, str):
                 clone.stop = [clone.stop]
             clone.stop += self.chat_template.stop_str

-        return clone or self.default_sampling_para
+        return clone

     def __del__(self):
         self.end()
sglang/lang/ir.py CHANGED
@@ -17,6 +17,7 @@ REGEX_STR = r"\"[\w\d\s]*\""  # bugs with regex r"\".*\"" in interegular pkg
 @dataclasses.dataclass
 class SglSamplingParams:
     max_new_tokens: int = 128
+    min_new_tokens: int = 0
     stop: Union[str, List[str]] = ()
     stop_token_ids: Optional[List[int]] = ()
     temperature: float = 1.0
@@ -39,6 +40,7 @@ class SglSamplingParams:
     def clone(self):
         return SglSamplingParams(
             self.max_new_tokens,
+            self.min_new_tokens,
             self.stop,
             self.stop_token_ids,
             self.temperature,
@@ -113,6 +115,7 @@ class SglSamplingParams:
     def to_srt_kwargs(self):
         return {
             "max_new_tokens": self.max_new_tokens,
+            "min_new_tokens": self.min_new_tokens,
             "stop": self.stop,
             "stop_token_ids": self.stop_token_ids,
             "temperature": self.temperature,
@@ -150,8 +153,8 @@ class SglFunction:
         self,
         *args,
         max_new_tokens: int = 128,
-        stop: Union[str, List[str]] = [],
-        stop_token_ids: Optional[List[int]] = [],
+        stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -169,6 +172,12 @@
     ):
         from sglang.lang.interpreter import run_program

+        # avoid using [] as the default arg: https://nikos7am.com/posts/mutable-default-arguments/
+        if stop is None:
+            stop = []
+        if stop_token_ids is None:
+            stop_token_ids = []
+
         default_sampling_para = SglSamplingParams(
             max_new_tokens=max_new_tokens,
             stop=stop,
@@ -193,8 +202,8 @@ class SglFunction:
         batch_kwargs,
         *,
         max_new_tokens: int = 128,
-        stop: Union[str, List[str]] = (),
-        stop_token_ids: Optional[List[int]] = [],
+        stop: Optional[Union[str, List[str]]] = None,
+        stop_token_ids: Optional[List[int]] = None,
         temperature: float = 1.0,
         top_p: float = 1.0,
         top_k: int = -1,
@@ -212,6 +221,11 @@
     ):
         from sglang.lang.interpreter import run_program_batch

+        if stop is None:
+            stop = []
+        if stop_token_ids is None:
+            stop_token_ids = []
+
         assert isinstance(batch_kwargs, (list, tuple))
         if len(batch_kwargs) == 0:
             return []
@@ -413,6 +427,7 @@ class SglGen(SglExpr):
         self,
         name: Optional[str] = None,
         max_new_tokens: Optional[int] = None,
+        min_new_tokens: Optional[int] = None,
         stop: Optional[Union[str, List[str]]] = None,
         stop_token_ids: Optional[List[int]] = None,
         temperature: Optional[float] = None,
@@ -435,6 +450,7 @@
         self.name = name
         self.sampling_params = SglSamplingParams(
             max_new_tokens=max_new_tokens,
+            min_new_tokens=min_new_tokens,
             stop=stop,
             stop_token_ids=stop_token_ids,
             temperature=temperature,
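The `stop=[]` / `stop_token_ids=[]` defaults removed above are the mutable-default-argument pitfall the linked post describes: the list is created once at function definition time and shared across calls, so appending chat-template stop strings in one call could leak into the next. A standalone illustration of the Python behavior (not sglang-specific):

    def broken(item, bucket=[]):      # one list object shared by every call
        bucket.append(item)
        return bucket

    def fixed(item, bucket=None):     # fresh list per call, as run()/run_batch() now do
        if bucket is None:
            bucket = []
        bucket.append(item)
        return bucket

    print(broken("a"), broken("b"))   # ['a', 'b'] ['a', 'b']  -- state leaks between calls
    print(fixed("a"), fixed("b"))     # ['a'] ['b']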
sglang/srt/configs/model_config.py CHANGED
@@ -49,13 +49,13 @@ class ModelConfig:
         if context_length is not None:
             self.context_len = context_length
         else:
-            self.context_len = get_context_length(self.hf_config)
+            self.context_len = get_context_length(self.hf_text_config)

-        # Unify the config keys for hf_config
+        # Unify the config keys for hf_text_config
         self.head_dim = getattr(
-            self.hf_config,
+            self.hf_text_config,
             "head_dim",
-            self.hf_config.hidden_size // self.hf_config.num_attention_heads,
+            self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )

         # FIXME: temporary special judge for deepseek v2 MLA architecture
@@ -72,8 +72,10 @@ class ModelConfig:
         else:
             self.attention_arch = AttentionArch.MHA

-        self.num_attention_heads = self.hf_config.num_attention_heads
-        self.num_key_value_heads = getattr(self.hf_config, "num_key_value_heads", None)
+        self.num_attention_heads = self.hf_text_config.num_attention_heads
+        self.num_key_value_heads = getattr(
+            self.hf_text_config, "num_key_value_heads", None
+        )

         # for Dbrx and MPT models
         if self.hf_config.model_type in ["dbrx", "mpt"]:
@@ -83,9 +85,9 @@

         if self.num_key_value_heads is None:
             self.num_key_value_heads = self.num_attention_heads
-        self.hidden_size = self.hf_config.hidden_size
-        self.num_hidden_layers = self.hf_config.num_hidden_layers
-        self.vocab_size = self.hf_config.vocab_size
+        self.hidden_size = self.hf_text_config.hidden_size
+        self.num_hidden_layers = self.hf_text_config.num_hidden_layers
+        self.vocab_size = self.hf_text_config.vocab_size

     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -14,13 +14,17 @@ limitations under the License.
 """

 """Cache for the compressed finite state machine."""
+import logging

+from interegular import InvalidSyntax, parse_pattern
 from outlines.fsm.json_schema import build_regex_from_schema
 from transformers import AutoTokenizer

 from sglang.srt.constrained import RegexGuide, TransformerTokenizer
 from sglang.srt.constrained.base_tool_cache import BaseToolCache

+logger = logging.getLogger(__name__)
+

 class FSMCache(BaseToolCache):
     def __init__(
@@ -76,5 +80,9 @@ class FSMCache(BaseToolCache):
             regex = key_string
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
-
+        try:
+            parse_pattern(regex)
+        except InvalidSyntax as e:
+            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
+            return None, regex
         return RegexGuide(regex, self.outlines_tokenizer), regex
sglang/srt/constrained/jump_forward.py CHANGED
@@ -19,10 +19,12 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
 """

 import dataclasses
+import logging
 from collections import defaultdict

 import interegular
 import outlines.caching
+from interegular import InvalidSyntax

 from sglang.srt.constrained import (
     FSMInfo,
@@ -34,6 +36,8 @@ from sglang.srt.constrained.base_tool_cache import BaseToolCache

 IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

+logger = logging.getLogger(__name__)
+

 @dataclasses.dataclass
 class JumpEdge:
@@ -47,7 +51,12 @@ class JumpForwardMap:
     def __init__(self, regex_string):
         @disk_cache()
         def _init_state_to_jump_forward(regex_string):
-            regex_pattern = interegular.parse_pattern(regex_string)
+            try:
+                regex_pattern = interegular.parse_pattern(regex_string)
+            except InvalidSyntax as e:
+                logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+                self.state_to_jump_forward = None
+                return

             byte_fsm = make_byte_level_fsm(
                 regex_pattern.to_fsm().reduce(), keep_utf8=True
@@ -165,7 +174,11 @@ class JumpForwardCache(BaseToolCache):
         super().__init__()

     def init_value(self, regex):
-        return JumpForwardMap(regex)
+        forward_map = JumpForwardMap(regex)
+        if forward_map.state_to_jump_forward:
+            return forward_map
+        else:
+            return None


 def test_main(regex_string):
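Both constrained-decoding caches now probe a pattern with interegular before building the heavier guide objects, so an unparsable regex degrades to "no guide" rather than raising inside the scheduler. A hedged standalone sketch of that validation step (not the exact sglang call sites):

    import logging

    from interegular import InvalidSyntax, parse_pattern

    logger = logging.getLogger(__name__)

    def regex_is_usable(regex: str) -> bool:
        """Return True if interegular can parse the pattern; log and skip it otherwise."""
        try:
            parse_pattern(regex)
            return True
        except InvalidSyntax as e:
            logger.warning(f"skip invalid regex guide: {regex=}, {e=}")
            return False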
sglang/srt/hf_transformers_utils.py CHANGED
@@ -129,6 +129,7 @@ def get_tokenizer(
             *args,
             trust_remote_code=trust_remote_code,
             tokenizer_revision=tokenizer_revision,
+            clean_up_tokenization_spaces=False,
             **kwargs,
         )
     except TypeError as e: