sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py CHANGED
@@ -84,7 +84,7 @@ register_chat_template(
84
84
  "system": ("SYSTEM:", "\n"),
85
85
  "user": ("USER:", "\n"),
86
86
  "assistant": ("ASSISTANT:", "\n"),
87
- },
87
+ }
88
88
  )
89
89
  )
90
90
 
@@ -116,6 +116,23 @@ register_chat_template(
116
116
  )
117
117
  )
118
118
 
119
+ # There is default system prompt for qwen
120
+ # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
121
+ # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
122
+ register_chat_template(
123
+ ChatTemplate(
124
+ name="qwen",
125
+ default_system_prompt="You are a helpful assistant.",
126
+ role_prefix_and_suffix={
127
+ "system": ("<|im_start|>system\n", "<|im_end|>\n"),
128
+ "user": ("<|im_start|>user\n", "<|im_end|>\n"),
129
+ "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
130
+ },
131
+ style=ChatTemplateStyle.PLAIN,
132
+ stop_str=("<|im_end|>",),
133
+ )
134
+ )
135
+
119
136
 
120
137
  register_chat_template(
121
138
  ChatTemplate(
@@ -132,6 +149,7 @@ register_chat_template(
132
149
  )
133
150
  )
134
151
 
152
+ # Reference: https://github.com/lm-sys/FastChat/blob/main/docs/vicuna_weights_version.md#prompt-template
135
153
  register_chat_template(
136
154
  ChatTemplate(
137
155
  name="vicuna_v1.1",
@@ -148,6 +166,20 @@ register_chat_template(
148
166
  )
149
167
  )
150
168
 
169
+ # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
170
+ register_chat_template(
171
+ ChatTemplate(
172
+ name="yi-1.5",
173
+ default_system_prompt=None,
174
+ role_prefix_and_suffix={
175
+ "system": ("", ""),
176
+ "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
177
+ "assistant": ("", "<|im_end|>\n"),
178
+ },
179
+ style=ChatTemplateStyle.PLAIN,
180
+ stop_str=("<|im_end|>",)
181
+ )
182
+ )
151
183
 
152
184
  register_chat_template(
153
185
  ChatTemplate(
@@ -187,7 +219,7 @@ register_chat_template(
187
219
  # Reference: https://github.com/01-ai/Yi/tree/main/VL#major-difference-with-llava
188
220
  register_chat_template(
189
221
  ChatTemplate(
190
- name="yi",
222
+ name="yi-vl",
191
223
  default_system_prompt=(
192
224
  "This is a chat between an inquisitive human and an AI assistant. Assume the role of the AI assistant. Read all the images carefully, and respond to the human's questions with informative, helpful, detailed and polite answers."
193
225
  "这是一个好奇的人类和一个人工智能助手之间的对话。假设你扮演这个AI助手的角色。仔细阅读所有的图像,并对人类的问题做出信息丰富、有帮助、详细的和礼貌的回答。"
@@ -289,8 +321,9 @@ def match_chat_ml(model_path: str):
289
321
  model_path = model_path.lower()
290
322
  if "tinyllama" in model_path:
291
323
  return get_chat_template("chatml")
292
- if "qwen" in model_path and "chat" in model_path:
293
- return get_chat_template("chatml")
324
+ # Now the suffix for qwen2 chat model is "instruct"
325
+ if "qwen" in model_path and ("chat" in model_path or "instruct" in model_path):
326
+ return get_chat_template("qwen")
294
327
  if (
295
328
  "llava-v1.6-34b" in model_path
296
329
  or "llava-v1.6-yi-34b" in model_path
@@ -302,8 +335,10 @@ def match_chat_ml(model_path: str):
302
335
  @register_chat_template_matching_function
303
336
  def match_chat_yi(model_path: str):
304
337
  model_path = model_path.lower()
305
- if "yi" in model_path and "llava" not in model_path:
306
- return get_chat_template("yi")
338
+ if "yi-vl" in model_path and "llava" not in model_path:
339
+ return get_chat_template("yi-vl")
340
+ elif "yi-1.5" in model_path and "chat" in model_path:
341
+ return get_chat_template("yi-1.5")
307
342
 
308
343
 
309
344
  @register_chat_template_matching_function
sglang/lang/compiler.py CHANGED
@@ -4,7 +4,7 @@ from queue import Queue
4
4
  from typing import List, Union
5
5
 
6
6
  from sglang.global_config import global_config
7
- from sglang.lang.interpreter import ProgramState, StreamExecutor, pin_program
7
+ from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
8
8
  from sglang.lang.ir import (
9
9
  SglArgument,
10
10
  SglConstantText,
@@ -184,7 +184,7 @@ class CompiledFunction:
184
184
 
185
185
  # Extract prefix by tracing and cache it
186
186
  if len(batch_kwargs) > 1:
187
- pin_program(self.function, backend)
187
+ cache_program(self.function, backend)
188
188
 
189
189
  # Run all programs
190
190
  if num_threads == "auto":
sglang/lang/interpreter.py CHANGED
@@ -507,7 +507,7 @@ class StreamExecutor:
507
507
  )
508
508
  return
509
509
 
510
- else: # Speculative execution on models with completion interface
510
+ else: # Speculative execution on models with completion interface
511
511
  comp, meta_info = self._spec_gen(sampling_params)
512
512
 
513
513
  self.text_ += comp
@@ -523,9 +523,9 @@ class StreamExecutor:
523
523
  self, sampling_params=sampling_params
524
524
  )
525
525
 
526
+ self.variables[name] = ""
526
527
  self.stream_var_event[name].set()
527
528
 
528
- self.variables[name] = ""
529
529
  for comp, meta_info in generator:
530
530
  self.text_ += comp
531
531
  self.variables[name] += comp
@@ -668,6 +668,10 @@ class StreamExecutor:
668
668
  "frequency_penalty",
669
669
  "presence_penalty",
670
670
  "ignore_eos",
671
+ "return_logprob",
672
+ "logprob_start_len",
673
+ "top_logprobs_num",
674
+ "return_text_in_logprobs",
671
675
  "dtype",
672
676
  "regex",
673
677
  ]:
sglang/lang/ir.py CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
23
23
  frequency_penalty: float = 0.0
24
24
  presence_penalty: float = 0.0
25
25
  ignore_eos: bool = False
26
+ return_logprob: Optional[bool] = None
27
+ logprob_start_len: Optional[int] = None,
28
+ top_logprobs_num: Optional[int] = None,
29
+ return_text_in_logprobs: Optional[bool] = None,
26
30
 
27
31
  # for constrained generation, not included in to_xxx_kwargs
28
32
  dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
37
41
  self.top_k,
38
42
  self.frequency_penalty,
39
43
  self.presence_penalty,
44
+ self.ignore_eos,
45
+ self.return_logprob,
46
+ self.logprob_start_len,
47
+ self.top_logprobs_num,
48
+ self.return_text_in_logprobs,
40
49
  )
41
50
 
42
51
  def to_openai_kwargs(self):
@@ -81,12 +90,10 @@ class SglSamplingParams:
81
90
  "top_p": self.top_p,
82
91
  "top_k": self.top_k,
83
92
  }
84
-
93
+
85
94
  def to_litellm_kwargs(self):
86
95
  if self.regex is not None:
87
- warnings.warn(
88
- "Regular expression is not supported in the LiteLLM backend."
89
- )
96
+ warnings.warn("Regular expression is not supported in the LiteLLM backend.")
90
97
  return {
91
98
  "max_tokens": self.max_new_tokens,
92
99
  "stop": self.stop or None,
@@ -122,6 +129,7 @@ class SglFunction:
122
129
  argspec = inspect.getfullargspec(func)
123
130
  assert argspec.args[0] == "s", 'The first argument must be "s"'
124
131
  self.arg_names = argspec.args[1:]
132
+ self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
125
133
 
126
134
  def bind(self, **kwargs):
127
135
  assert all(key in self.arg_names for key in kwargs)
@@ -140,6 +148,10 @@ class SglFunction:
140
148
  frequency_penalty: float = 0.0,
141
149
  presence_penalty: float = 0.0,
142
150
  ignore_eos: bool = False,
151
+ return_logprob: Optional[bool] = None,
152
+ logprob_start_len: Optional[int] = None,
153
+ top_logprobs_num: Optional[int] = None,
154
+ return_text_in_logprobs: Optional[bool] = None,
143
155
  stream: bool = False,
144
156
  backend=None,
145
157
  **kwargs,
@@ -155,6 +167,10 @@ class SglFunction:
155
167
  frequency_penalty=frequency_penalty,
156
168
  presence_penalty=presence_penalty,
157
169
  ignore_eos=ignore_eos,
170
+ return_logprob=return_logprob,
171
+ logprob_start_len=logprob_start_len,
172
+ top_logprobs_num=top_logprobs_num,
173
+ return_text_in_logprobs=return_text_in_logprobs,
158
174
  )
159
175
  backend = backend or global_config.default_backend
160
176
  return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -171,6 +187,10 @@ class SglFunction:
171
187
  frequency_penalty: float = 0.0,
172
188
  presence_penalty: float = 0.0,
173
189
  ignore_eos: bool = False,
190
+ return_logprob: Optional[bool] = None,
191
+ logprob_start_len: Optional[int] = None,
192
+ top_logprobs_num: Optional[int] = None,
193
+ return_text_in_logprobs: Optional[bool] = None,
174
194
  backend=None,
175
195
  num_threads: Union[str, int] = "auto",
176
196
  progress_bar: bool = False,
@@ -180,7 +200,20 @@ class SglFunction:
180
200
  assert isinstance(batch_kwargs, (list, tuple))
181
201
  if len(batch_kwargs) == 0:
182
202
  return []
183
- assert isinstance(batch_kwargs[0], dict)
203
+ if not isinstance(batch_kwargs[0], dict):
204
+ num_programs = len(batch_kwargs)
205
+ # change the list of argument values to dict of arg_name -> arg_value
206
+ batch_kwargs = [
207
+ {self.arg_names[i]: v for i, v in enumerate(arg_values)}
208
+ for arg_values in batch_kwargs
209
+ if isinstance(arg_values, (list, tuple))
210
+ and len(self.arg_names) - len(self.arg_defaults)
211
+ <= len(arg_values)
212
+ <= len(self.arg_names)
213
+ ]
214
+ # Ensure to raise an exception if the number of arguments mismatch
215
+ if len(batch_kwargs) != num_programs:
216
+ raise Exception("Given arguments mismatch the SGL function signature")
184
217
 
185
218
  default_sampling_para = SglSamplingParams(
186
219
  max_new_tokens=max_new_tokens,
@@ -191,6 +224,10 @@ class SglFunction:
191
224
  frequency_penalty=frequency_penalty,
192
225
  presence_penalty=presence_penalty,
193
226
  ignore_eos=ignore_eos,
227
+ return_logprob=return_logprob,
228
+ logprob_start_len=logprob_start_len,
229
+ top_logprobs_num=top_logprobs_num,
230
+ return_text_in_logprobs=return_text_in_logprobs,
194
231
  )
195
232
  backend = backend or global_config.default_backend
196
233
  return run_program_batch(
@@ -338,7 +375,7 @@ class SglArgument(SglExpr):
338
375
 
339
376
 
340
377
  class SglImage(SglExpr):
341
- def __init__(self, path):
378
+ def __init__(self, path: str):
342
379
  self.path = path
343
380
 
344
381
  def __repr__(self) -> str:
@@ -346,7 +383,7 @@ class SglImage(SglExpr):
346
383
 
347
384
 
348
385
  class SglVideo(SglExpr):
349
- def __init__(self, path, num_frames):
386
+ def __init__(self, path: str, num_frames: int):
350
387
  self.path = path
351
388
  self.num_frames = num_frames
352
389
 
@@ -357,18 +394,23 @@ class SglVideo(SglExpr):
357
394
  class SglGen(SglExpr):
358
395
  def __init__(
359
396
  self,
360
- name,
361
- max_new_tokens,
362
- stop,
363
- temperature,
364
- top_p,
365
- top_k,
366
- frequency_penalty,
367
- presence_penalty,
368
- ignore_eos,
369
- dtype,
370
- regex,
397
+ name: Optional[str] = None,
398
+ max_new_tokens: Optional[int] = None,
399
+ stop: Optional[Union[str, List[str]]] = None,
400
+ temperature: Optional[float] = None,
401
+ top_p: Optional[float] = None,
402
+ top_k: Optional[int] = None,
403
+ frequency_penalty: Optional[float] = None,
404
+ presence_penalty: Optional[float] = None,
405
+ ignore_eos: Optional[bool] = None,
406
+ return_logprob: Optional[bool] = None,
407
+ logprob_start_len: Optional[int] = None,
408
+ top_logprobs_num: Optional[int] = None,
409
+ return_text_in_logprobs: Optional[bool] = None,
410
+ dtype: Optional[type] = None,
411
+ regex: Optional[str] = None,
371
412
  ):
413
+ """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
372
414
  super().__init__()
373
415
  self.name = name
374
416
  self.sampling_params = SglSamplingParams(
@@ -380,6 +422,10 @@ class SglGen(SglExpr):
380
422
  frequency_penalty=frequency_penalty,
381
423
  presence_penalty=presence_penalty,
382
424
  ignore_eos=ignore_eos,
425
+ return_logprob=return_logprob,
426
+ logprob_start_len=logprob_start_len,
427
+ top_logprobs_num=top_logprobs_num,
428
+ return_text_in_logprobs=return_text_in_logprobs,
383
429
  dtype=dtype,
384
430
  regex=regex,
385
431
  )
@@ -389,7 +435,7 @@ class SglGen(SglExpr):
389
435
 
390
436
 
391
437
  class SglConstantText(SglExpr):
392
- def __init__(self, value):
438
+ def __init__(self, value: str):
393
439
  super().__init__()
394
440
  self.value = value
395
441
 
@@ -398,7 +444,7 @@ class SglConstantText(SglExpr):
398
444
 
399
445
 
400
446
  class SglRoleBegin(SglExpr):
401
- def __init__(self, role):
447
+ def __init__(self, role: str):
402
448
  super().__init__()
403
449
  self.role = role
404
450
 
@@ -407,7 +453,7 @@ class SglRoleBegin(SglExpr):
407
453
 
408
454
 
409
455
  class SglRoleEnd(SglExpr):
410
- def __init__(self, role):
456
+ def __init__(self, role: str):
411
457
  super().__init__()
412
458
  self.role = role
413
459
 
@@ -416,7 +462,7 @@ class SglRoleEnd(SglExpr):
416
462
 
417
463
 
418
464
  class SglSelect(SglExpr):
419
- def __init__(self, name, choices, temperature):
465
+ def __init__(self, name: str, choices: List[str], temperature: float):
420
466
  super().__init__()
421
467
  self.name = name
422
468
  self.choices = choices
@@ -427,7 +473,7 @@ class SglSelect(SglExpr):
427
473
 
428
474
 
429
475
  class SglFork(SglExpr):
430
- def __init__(self, number, position_ids_offset=None):
476
+ def __init__(self, number: int, position_ids_offset=None):
431
477
  super().__init__()
432
478
  self.number = number
433
479
  self.position_ids_offset = position_ids_offset
@@ -440,7 +486,7 @@ class SglFork(SglExpr):
440
486
 
441
487
 
442
488
  class SglGetForkItem(SglExpr):
443
- def __init__(self, index):
489
+ def __init__(self, index: int):
444
490
  super().__init__()
445
491
  self.index = index
446
492
 
@@ -449,7 +495,7 @@ class SglGetForkItem(SglExpr):
449
495
 
450
496
 
451
497
  class SglVariable(SglExpr):
452
- def __init__(self, name, source):
498
+ def __init__(self, name: str, source):
453
499
  super().__init__()
454
500
  self.name = name
455
501
  self.source = source
@@ -459,7 +505,7 @@ class SglVariable(SglExpr):
459
505
 
460
506
 
461
507
  class SglVarScopeBegin(SglExpr):
462
- def __init__(self, name):
508
+ def __init__(self, name: str):
463
509
  super().__init__()
464
510
  self.name = name
465
511
 
@@ -468,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
468
514
 
469
515
 
470
516
  class SglVarScopeEnd(SglExpr):
471
- def __init__(self, name):
517
+ def __init__(self, name: str):
472
518
  super().__init__()
473
519
  self.name = name
474
520
 
@@ -490,4 +536,4 @@ class SglCommitLazy(SglExpr):
490
536
  super().__init__()
491
537
 
492
538
  def __repr__(self):
493
- return f"CommitLazy()"
539
+ return "CommitLazy()"
sglang/launch_server.py CHANGED
@@ -1,6 +1,9 @@
1
+ """Launch the inference server."""
2
+
1
3
  import argparse
2
4
 
3
- from sglang.srt.server import ServerArgs, launch_server
5
+ from sglang.srt.server import launch_server
6
+ from sglang.srt.server_args import ServerArgs
4
7
 
5
8
  if __name__ == "__main__":
6
9
  parser = argparse.ArgumentParser()
sglang/launch_server_llavavid.py CHANGED
@@ -1,10 +1,11 @@
1
+ """Launch the inference server for Llava-video model."""
2
+
1
3
  import argparse
2
4
  import multiprocessing as mp
3
5
 
4
6
  from sglang.srt.server import ServerArgs, launch_server
5
7
 
6
8
  if __name__ == "__main__":
7
-
8
9
  model_overide_args = {}
9
10
 
10
11
  model_overide_args["mm_spatial_pool_stride"] = 2
sglang/srt/constrained/__init__.py CHANGED
@@ -1,13 +1,20 @@
1
1
  import json
2
2
  from typing import Dict, Optional, Union
3
3
 
4
- from outlines.caching import cache as disk_cache
5
- from outlines.caching import disable_cache
6
- from outlines.fsm.fsm import RegexFSM
7
- from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
8
- from outlines.models.transformers import TransformerTokenizer
9
4
  from pydantic import BaseModel
10
5
 
6
+ try:
7
+ from outlines.caching import cache as disk_cache
8
+ from outlines.caching import disable_cache
9
+ from outlines.fsm.guide import RegexGuide
10
+ from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
11
+ from outlines.models.transformers import TransformerTokenizer
12
+ except ImportError as e:
13
+ print(
14
+ f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
15
+ )
16
+ raise
17
+
11
18
  try:
12
19
  from outlines.fsm.json_schema import build_regex_from_object
13
20
  except ImportError:
@@ -28,11 +35,12 @@ except ImportError:
28
35
 
29
36
 
30
37
  __all__ = [
31
- "RegexFSM",
38
+ "RegexGuide",
32
39
  "FSMInfo",
33
40
  "make_deterministic_fsm",
34
41
  "build_regex_from_object",
35
42
  "TransformerTokenizer",
36
43
  "disk_cache",
37
44
  "disable_cache",
45
+ "make_byte_level_fsm",
38
46
  ]
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -1,4 +1,6 @@
1
- from sglang.srt.constrained import RegexFSM, TransformerTokenizer
1
+ """Cache for the compressed finite state machine."""
2
+
3
+ from sglang.srt.constrained import RegexGuide, TransformerTokenizer
2
4
  from sglang.srt.constrained.base_cache import BaseCache
3
5
 
4
6
 
@@ -6,7 +8,8 @@ class FSMCache(BaseCache):
6
8
  def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
7
9
  super().__init__(enable=enable)
8
10
 
9
- if tokenizer_path.endswith(".json"):
11
+ if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
12
+ # Do not support TiktokenTokenizer or SentencePieceTokenizer
10
13
  return
11
14
 
12
15
  from importlib.metadata import version
@@ -25,4 +28,4 @@ class FSMCache(BaseCache):
25
28
  )
26
29
 
27
30
  def init_value(self, regex):
28
- return RegexFSM(regex, self.outlines_tokenizer)
31
+ return RegexGuide(regex, self.outlines_tokenizer)
sglang/srt/constrained/jump_forward.py CHANGED
@@ -1,17 +1,43 @@
1
- import interegular
1
+ """
2
+ Faster constrained decoding.
3
+ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
4
+ """
5
+
6
+ import dataclasses
7
+ from collections import defaultdict
2
8
 
3
- from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
9
+ import interegular
10
+ import outlines.caching
11
+
12
+ from sglang.srt.constrained import (
13
+ FSMInfo,
14
+ disk_cache,
15
+ make_byte_level_fsm,
16
+ make_deterministic_fsm,
17
+ )
4
18
  from sglang.srt.constrained.base_cache import BaseCache
5
19
 
6
20
  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
7
21
 
8
22
 
23
+ @dataclasses.dataclass
24
+ class JumpEdge:
25
+ symbol: str = None
26
+ symbol_next_state: int = None
27
+ byte: int = None
28
+ byte_next_state: int = None
29
+
30
+
9
31
  class JumpForwardMap:
10
32
  def __init__(self, regex_string):
11
33
  @disk_cache()
12
34
  def _init_state_to_jump_forward(regex_string):
13
35
  regex_pattern = interegular.parse_pattern(regex_string)
14
- regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
36
+
37
+ byte_fsm = make_byte_level_fsm(
38
+ regex_pattern.to_fsm().reduce(), keep_utf8=True
39
+ )
40
+ regex_fsm, _ = make_deterministic_fsm(byte_fsm)
15
41
 
16
42
  fsm_info: FSMInfo = regex_fsm.fsm_info
17
43
 
@@ -21,40 +47,93 @@ class JumpForwardMap:
21
47
  id_to_symbol.setdefault(id_, []).append(symbol)
22
48
 
23
49
  transitions = fsm_info.transitions
24
- dirty_states = set()
50
+ outgoings_ct = defaultdict(int)
25
51
  state_to_jump_forward = {}
26
52
 
27
53
  for (state, id_), next_state in transitions.items():
28
- if state in dirty_states:
29
- continue
30
- if state in state_to_jump_forward:
31
- dirty_states.add(state)
32
- del state_to_jump_forward[state]
54
+ if id_ == fsm_info.alphabet_anything_value:
33
55
  continue
34
- if len(id_to_symbol[id_]) > 1:
35
- dirty_states.add(state)
56
+ symbols = id_to_symbol[id_]
57
+ for c in symbols:
58
+ if len(c) > 1:
59
+ # Skip byte level transitions
60
+ continue
61
+
62
+ outgoings_ct[state] += 1
63
+ if outgoings_ct[state] > 1:
64
+ if state in state_to_jump_forward:
65
+ del state_to_jump_forward[state]
66
+ break
67
+
68
+ state_to_jump_forward[state] = JumpEdge(
69
+ symbol=c,
70
+ symbol_next_state=next_state,
71
+ )
72
+
73
+ # Process the byte level jump forward
74
+ outgoings_ct = defaultdict(int)
75
+ for (state, id_), next_state in transitions.items():
76
+ if id_ == fsm_info.alphabet_anything_value:
36
77
  continue
37
-
38
- state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
78
+ symbols = id_to_symbol[id_]
79
+ for c in symbols:
80
+ byte_ = None
81
+ if len(c) == 1 and ord(c) < 0x80:
82
+ # ASCII character
83
+ byte_ = ord(c)
84
+ elif len(c) > 1:
85
+ # FIXME: This logic is due to the leading \x00
86
+ # https://github.com/outlines-dev/outlines/pull/930
87
+ byte_ = int(symbols[0][1:], 16)
88
+
89
+ if byte_ is not None:
90
+ outgoings_ct[state] += 1
91
+ if outgoings_ct[state] > 1:
92
+ if state in state_to_jump_forward:
93
+ del state_to_jump_forward[state]
94
+ break
95
+ e = state_to_jump_forward.get(state, JumpEdge())
96
+ e.byte = byte_
97
+ e.byte_next_state = next_state
98
+ state_to_jump_forward[state] = e
39
99
 
40
100
  return state_to_jump_forward
41
101
 
42
102
  self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
43
103
 
44
- def valid_states(self):
45
- return self.state_to_jump_forward.keys()
104
+ def jump_forward_symbol(self, state):
105
+ jump_forward_str = ""
106
+ next_state = state
107
+ while state in self.state_to_jump_forward:
108
+ e = self.state_to_jump_forward[state]
109
+ if e.symbol is None:
110
+ break
111
+ jump_forward_str += e.symbol
112
+ next_state = e.symbol_next_state
113
+ state = next_state
46
114
 
47
- def jump_forward(self, state):
115
+ return jump_forward_str, next_state
116
+
117
+ def jump_forward_byte(self, state):
48
118
  if state not in self.state_to_jump_forward:
49
119
  return None
50
120
 
51
- jump_forward_str = ""
121
+ jump_forward_bytes = []
52
122
  next_state = None
53
123
  while state in self.state_to_jump_forward:
54
- symbol, next_state = self.state_to_jump_forward[state]
55
- jump_forward_str += symbol
124
+ e = self.state_to_jump_forward[state]
125
+ assert e.byte is not None and e.byte_next_state is not None
126
+ jump_forward_bytes.append((e.byte, e.byte_next_state))
127
+ next_state = e.byte_next_state
56
128
  state = next_state
57
- return jump_forward_str, next_state
129
+
130
+ return jump_forward_bytes
131
+
132
+ def is_jump_forward_symbol_state(self, state):
133
+ return (
134
+ state in self.state_to_jump_forward
135
+ and self.state_to_jump_forward[state].symbol is not None
136
+ )
58
137
 
59
138
 
60
139
  class JumpForwardCache(BaseCache):
@@ -65,12 +144,21 @@ class JumpForwardCache(BaseCache):
65
144
  return JumpForwardMap(regex)
66
145
 
67
146
 
68
- def test_main():
69
- regex_string = r"The google's DNS sever address is " + IP_REGEX
147
+ def test_main(regex_string):
70
148
  jump_forward_map = JumpForwardMap(regex_string)
71
- for state in jump_forward_map.valid_states():
72
- print(state, f'"{jump_forward_map.jump_forward(state)}"')
149
+ for state, e in jump_forward_map.state_to_jump_forward.items():
150
+ if e.symbol is not None:
151
+ jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
152
+ print(f"{state} -> {next_state}", jump_forward_str)
153
+ bytes_ = jump_forward_map.jump_forward_byte(state)
154
+ print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
73
155
 
74
156
 
75
157
  if __name__ == "__main__":
76
- test_main()
158
+ import outlines
159
+
160
+ outlines.caching.clear_cache()
161
+ test_main(r"The google's DNS sever address is " + IP_REGEX)
162
+ test_main(r"霍格沃茨特快列车|霍比特人比尔博")
163
+ # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
164
+ # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
sglang/srt/conversation.py CHANGED
@@ -1,3 +1,5 @@
1
+ """Conversation templates."""
2
+
1
3
  # Adapted from
2
4
  # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
5
  import dataclasses