sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/lang/ir.py CHANGED
@@ -23,6 +23,10 @@ class SglSamplingParams:
23
23
  frequency_penalty: float = 0.0
24
24
  presence_penalty: float = 0.0
25
25
  ignore_eos: bool = False
26
+ return_logprob: Optional[bool] = None
27
+ logprob_start_len: Optional[int] = (None,)
28
+ top_logprobs_num: Optional[int] = (None,)
29
+ return_text_in_logprobs: Optional[bool] = (None,)
26
30
 
27
31
  # for constrained generation, not included in to_xxx_kwargs
28
32
  dtype: Optional[str] = None
@@ -37,6 +41,11 @@ class SglSamplingParams:
37
41
  self.top_k,
38
42
  self.frequency_penalty,
39
43
  self.presence_penalty,
44
+ self.ignore_eos,
45
+ self.return_logprob,
46
+ self.logprob_start_len,
47
+ self.top_logprobs_num,
48
+ self.return_text_in_logprobs,
40
49
  )
41
50
 
42
51
  def to_openai_kwargs(self):
@@ -82,6 +91,19 @@ class SglSamplingParams:
82
91
  "top_k": self.top_k,
83
92
  }
84
93
 
94
+ def to_litellm_kwargs(self):
95
+ if self.regex is not None:
96
+ warnings.warn("Regular expression is not supported in the LiteLLM backend.")
97
+ return {
98
+ "max_tokens": self.max_new_tokens,
99
+ "stop": self.stop or None,
100
+ "temperature": self.temperature,
101
+ "top_p": self.top_p,
102
+ "top_k": self.top_k,
103
+ "frequency_penalty": self.frequency_penalty,
104
+ "presence_penalty": self.presence_penalty,
105
+ }
106
+
85
107
  def to_srt_kwargs(self):
86
108
  return {
87
109
  "max_new_tokens": self.max_new_tokens,
@@ -97,9 +119,9 @@ class SglSamplingParams:
97
119
 
98
120
 
99
121
  class SglFunction:
100
- def __init__(self, func, api_num_spec_tokens=None, bind_arguments=None):
122
+ def __init__(self, func, num_api_spec_tokens=None, bind_arguments=None):
101
123
  self.func = func
102
- self.api_num_spec_tokens = api_num_spec_tokens
124
+ self.num_api_spec_tokens = num_api_spec_tokens
103
125
  self.bind_arguments = bind_arguments or {}
104
126
  self.pin_prefix_rid = None
105
127
 
@@ -107,6 +129,7 @@ class SglFunction:
107
129
  argspec = inspect.getfullargspec(func)
108
130
  assert argspec.args[0] == "s", 'The first argument must be "s"'
109
131
  self.arg_names = argspec.args[1:]
132
+ self.arg_defaults = argspec.defaults if argspec.defaults is not None else []
110
133
 
111
134
  def bind(self, **kwargs):
112
135
  assert all(key in self.arg_names for key in kwargs)
@@ -125,6 +148,10 @@ class SglFunction:
125
148
  frequency_penalty: float = 0.0,
126
149
  presence_penalty: float = 0.0,
127
150
  ignore_eos: bool = False,
151
+ return_logprob: Optional[bool] = None,
152
+ logprob_start_len: Optional[int] = None,
153
+ top_logprobs_num: Optional[int] = None,
154
+ return_text_in_logprobs: Optional[bool] = None,
128
155
  stream: bool = False,
129
156
  backend=None,
130
157
  **kwargs,
@@ -140,6 +167,10 @@ class SglFunction:
140
167
  frequency_penalty=frequency_penalty,
141
168
  presence_penalty=presence_penalty,
142
169
  ignore_eos=ignore_eos,
170
+ return_logprob=return_logprob,
171
+ logprob_start_len=logprob_start_len,
172
+ top_logprobs_num=top_logprobs_num,
173
+ return_text_in_logprobs=return_text_in_logprobs,
143
174
  )
144
175
  backend = backend or global_config.default_backend
145
176
  return run_program(self, backend, args, kwargs, default_sampling_para, stream)
@@ -156,6 +187,10 @@ class SglFunction:
156
187
  frequency_penalty: float = 0.0,
157
188
  presence_penalty: float = 0.0,
158
189
  ignore_eos: bool = False,
190
+ return_logprob: Optional[bool] = None,
191
+ logprob_start_len: Optional[int] = None,
192
+ top_logprobs_num: Optional[int] = None,
193
+ return_text_in_logprobs: Optional[bool] = None,
159
194
  backend=None,
160
195
  num_threads: Union[str, int] = "auto",
161
196
  progress_bar: bool = False,
@@ -165,7 +200,20 @@ class SglFunction:
165
200
  assert isinstance(batch_kwargs, (list, tuple))
166
201
  if len(batch_kwargs) == 0:
167
202
  return []
168
- assert isinstance(batch_kwargs[0], dict)
203
+ if not isinstance(batch_kwargs[0], dict):
204
+ num_programs = len(batch_kwargs)
205
+ # change the list of argument values to dict of arg_name -> arg_value
206
+ batch_kwargs = [
207
+ {self.arg_names[i]: v for i, v in enumerate(arg_values)}
208
+ for arg_values in batch_kwargs
209
+ if isinstance(arg_values, (list, tuple))
210
+ and len(self.arg_names) - len(self.arg_defaults)
211
+ <= len(arg_values)
212
+ <= len(self.arg_names)
213
+ ]
214
+ # Ensure to raise an exception if the number of arguments mismatch
215
+ if len(batch_kwargs) != num_programs:
216
+ raise Exception("Given arguments mismatch the SGL function signature")
169
217
 
170
218
  default_sampling_para = SglSamplingParams(
171
219
  max_new_tokens=max_new_tokens,
@@ -176,6 +224,10 @@ class SglFunction:
176
224
  frequency_penalty=frequency_penalty,
177
225
  presence_penalty=presence_penalty,
178
226
  ignore_eos=ignore_eos,
227
+ return_logprob=return_logprob,
228
+ logprob_start_len=logprob_start_len,
229
+ top_logprobs_num=top_logprobs_num,
230
+ return_text_in_logprobs=return_text_in_logprobs,
179
231
  )
180
232
  backend = backend or global_config.default_backend
181
233
  return run_program_batch(
@@ -193,17 +245,11 @@ class SglFunction:
193
245
  backend = backend or global_config.default_backend
194
246
  return trace_program(self, kwargs, backend)
195
247
 
196
- def pin(self, backend=None):
197
- from sglang.lang.interpreter import pin_program
248
+ def cache(self, backend=None):
249
+ from sglang.lang.interpreter import cache_program
198
250
 
199
251
  backend = backend or global_config.default_backend
200
- return pin_program(self, backend)
201
-
202
- def unpin(self, backend=None):
203
- from sglang.lang.interpreter import unpin_program
204
-
205
- backend = backend or global_config.default_backend
206
- return unpin_program(self, backend)
252
+ return cache_program(self, backend)
207
253
 
208
254
  def compile(self, *, backend=None):
209
255
  from sglang.lang.compiler import compile_func
@@ -329,28 +375,42 @@ class SglArgument(SglExpr):
329
375
 
330
376
 
331
377
  class SglImage(SglExpr):
332
- def __init__(self, path):
378
+ def __init__(self, path: str):
333
379
  self.path = path
334
380
 
335
381
  def __repr__(self) -> str:
336
382
  return f"SglImage({self.path})"
337
383
 
338
384
 
385
+ class SglVideo(SglExpr):
386
+ def __init__(self, path: str, num_frames: int):
387
+ self.path = path
388
+ self.num_frames = num_frames
389
+
390
+ def __repr__(self) -> str:
391
+ return f"SglVideo({self.path}, {self.num_frames})"
392
+
393
+
339
394
  class SglGen(SglExpr):
340
395
  def __init__(
341
396
  self,
342
- name,
343
- max_new_tokens,
344
- stop,
345
- temperature,
346
- top_p,
347
- top_k,
348
- frequency_penalty,
349
- presence_penalty,
350
- ignore_eos,
351
- dtype,
352
- regex,
397
+ name: Optional[str] = None,
398
+ max_new_tokens: Optional[int] = None,
399
+ stop: Optional[Union[str, List[str]]] = None,
400
+ temperature: Optional[float] = None,
401
+ top_p: Optional[float] = None,
402
+ top_k: Optional[int] = None,
403
+ frequency_penalty: Optional[float] = None,
404
+ presence_penalty: Optional[float] = None,
405
+ ignore_eos: Optional[bool] = None,
406
+ return_logprob: Optional[bool] = None,
407
+ logprob_start_len: Optional[int] = None,
408
+ top_logprobs_num: Optional[int] = None,
409
+ return_text_in_logprobs: Optional[bool] = None,
410
+ dtype: Optional[type] = None,
411
+ regex: Optional[str] = None,
353
412
  ):
413
+ """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
354
414
  super().__init__()
355
415
  self.name = name
356
416
  self.sampling_params = SglSamplingParams(
@@ -362,6 +422,10 @@ class SglGen(SglExpr):
362
422
  frequency_penalty=frequency_penalty,
363
423
  presence_penalty=presence_penalty,
364
424
  ignore_eos=ignore_eos,
425
+ return_logprob=return_logprob,
426
+ logprob_start_len=logprob_start_len,
427
+ top_logprobs_num=top_logprobs_num,
428
+ return_text_in_logprobs=return_text_in_logprobs,
365
429
  dtype=dtype,
366
430
  regex=regex,
367
431
  )
@@ -371,7 +435,7 @@ class SglGen(SglExpr):
371
435
 
372
436
 
373
437
  class SglConstantText(SglExpr):
374
- def __init__(self, value):
438
+ def __init__(self, value: str):
375
439
  super().__init__()
376
440
  self.value = value
377
441
 
@@ -380,7 +444,7 @@ class SglConstantText(SglExpr):
380
444
 
381
445
 
382
446
  class SglRoleBegin(SglExpr):
383
- def __init__(self, role):
447
+ def __init__(self, role: str):
384
448
  super().__init__()
385
449
  self.role = role
386
450
 
@@ -389,7 +453,7 @@ class SglRoleBegin(SglExpr):
389
453
 
390
454
 
391
455
  class SglRoleEnd(SglExpr):
392
- def __init__(self, role):
456
+ def __init__(self, role: str):
393
457
  super().__init__()
394
458
  self.role = role
395
459
 
@@ -398,7 +462,7 @@ class SglRoleEnd(SglExpr):
398
462
 
399
463
 
400
464
  class SglSelect(SglExpr):
401
- def __init__(self, name, choices, temperature):
465
+ def __init__(self, name: str, choices: List[str], temperature: float):
402
466
  super().__init__()
403
467
  self.name = name
404
468
  self.choices = choices
@@ -409,7 +473,7 @@ class SglSelect(SglExpr):
409
473
 
410
474
 
411
475
  class SglFork(SglExpr):
412
- def __init__(self, number, position_ids_offset=None):
476
+ def __init__(self, number: int, position_ids_offset=None):
413
477
  super().__init__()
414
478
  self.number = number
415
479
  self.position_ids_offset = position_ids_offset
@@ -422,7 +486,7 @@ class SglFork(SglExpr):
422
486
 
423
487
 
424
488
  class SglGetForkItem(SglExpr):
425
- def __init__(self, index):
489
+ def __init__(self, index: int):
426
490
  super().__init__()
427
491
  self.index = index
428
492
 
@@ -431,7 +495,7 @@ class SglGetForkItem(SglExpr):
431
495
 
432
496
 
433
497
  class SglVariable(SglExpr):
434
- def __init__(self, name, source):
498
+ def __init__(self, name: str, source):
435
499
  super().__init__()
436
500
  self.name = name
437
501
  self.source = source
@@ -441,7 +505,7 @@ class SglVariable(SglExpr):
441
505
 
442
506
 
443
507
  class SglVarScopeBegin(SglExpr):
444
- def __init__(self, name):
508
+ def __init__(self, name: str):
445
509
  super().__init__()
446
510
  self.name = name
447
511
 
@@ -450,7 +514,7 @@ class SglVarScopeBegin(SglExpr):
450
514
 
451
515
 
452
516
  class SglVarScopeEnd(SglExpr):
453
- def __init__(self, name):
517
+ def __init__(self, name: str):
454
518
  super().__init__()
455
519
  self.name = name
456
520
 
@@ -472,4 +536,4 @@ class SglCommitLazy(SglExpr):
472
536
  super().__init__()
473
537
 
474
538
  def __repr__(self):
475
- return f"CommitLazy()"
539
+ return "CommitLazy()"
sglang/lang/tracer.py CHANGED
@@ -109,19 +109,21 @@ class TracerProgramState(ProgramState):
109
109
  ########### Public API ###########
110
110
  ##################################
111
111
 
112
- def fork(self, number: int, position_ids_offset: Optional[List[int]] = None):
112
+ def fork(self, size: int = 1, position_ids_offset: Optional[List[int]] = None):
113
+ assert size >= 1
114
+
113
115
  if self.only_trace_prefix:
114
116
  raise StopTracing()
115
117
 
116
- fork_node = SglFork(number)
118
+ fork_node = SglFork(size)
117
119
  fork_node.prev_node = self.last_node
118
120
 
119
121
  states = [
120
122
  TracerProgramState(self.backend, self.arguments, self.only_trace_prefix)
121
- for _ in range(number)
123
+ for _ in range(size)
122
124
  ]
123
125
 
124
- for i in range(number):
126
+ for i in range(size):
125
127
  node = SglGetForkItem(i)
126
128
  node.prev_node = fork_node
127
129
  states[i].last_node = node
sglang/launch_server.py CHANGED
@@ -1,6 +1,9 @@
1
+ """Launch the inference server."""
2
+
1
3
  import argparse
2
4
 
3
- from sglang.srt.server import ServerArgs, launch_server
5
+ from sglang.srt.server import launch_server
6
+ from sglang.srt.server_args import ServerArgs
4
7
 
5
8
  if __name__ == "__main__":
6
9
  parser = argparse.ArgumentParser()
sglang/launch_server_llavavid.py ADDED
@@ -0,0 +1,32 @@
1
+ """Launch the inference server for Llava-video model."""
2
+
3
+ import argparse
4
+ import multiprocessing as mp
5
+
6
+ from sglang.srt.server import ServerArgs, launch_server
7
+
8
+ if __name__ == "__main__":
9
+ model_overide_args = {}
10
+
11
+ model_overide_args["mm_spatial_pool_stride"] = 2
12
+ model_overide_args["architectures"] = ["LlavaVidForCausalLM"]
13
+ model_overide_args["num_frames"] = 16
14
+ model_overide_args["model_type"] = "llavavid"
15
+ if model_overide_args["num_frames"] == 32:
16
+ model_overide_args["rope_scaling"] = {"factor": 2.0, "type": "linear"}
17
+ model_overide_args["max_sequence_length"] = 4096 * 2
18
+ model_overide_args["tokenizer_model_max_length"] = 4096 * 2
19
+ model_overide_args["model_max_length"] = 4096 * 2
20
+
21
+ parser = argparse.ArgumentParser()
22
+ ServerArgs.add_cli_args(parser)
23
+ args = parser.parse_args()
24
+
25
+ if "34b" in args.model_path.lower():
26
+ model_overide_args["image_token_index"] = 64002
27
+
28
+ server_args = ServerArgs.from_cli_args(args)
29
+
30
+ pipe_reader, pipe_writer = mp.Pipe(duplex=False)
31
+
32
+ launch_server(server_args, pipe_writer, model_overide_args)
sglang/srt/constrained/__init__.py CHANGED
@@ -1,13 +1,20 @@
1
1
  import json
2
2
  from typing import Dict, Optional, Union
3
3
 
4
- from outlines.caching import cache as disk_cache
5
- from outlines.caching import disable_cache
6
- from outlines.fsm.fsm import RegexFSM
7
- from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
8
- from outlines.models.transformers import TransformerTokenizer
9
4
  from pydantic import BaseModel
10
5
 
6
+ try:
7
+ from outlines.caching import cache as disk_cache
8
+ from outlines.caching import disable_cache
9
+ from outlines.fsm.guide import RegexGuide
10
+ from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
11
+ from outlines.models.transformers import TransformerTokenizer
12
+ except ImportError as e:
13
+ print(
14
+ f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
15
+ )
16
+ raise
17
+
11
18
  try:
12
19
  from outlines.fsm.json_schema import build_regex_from_object
13
20
  except ImportError:
@@ -28,11 +35,12 @@ except ImportError:
28
35
 
29
36
 
30
37
  __all__ = [
31
- "RegexFSM",
38
+ "RegexGuide",
32
39
  "FSMInfo",
33
40
  "make_deterministic_fsm",
34
41
  "build_regex_from_object",
35
42
  "TransformerTokenizer",
36
43
  "disk_cache",
37
44
  "disable_cache",
45
+ "make_byte_level_fsm",
38
46
  ]
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -1,4 +1,6 @@
1
- from sglang.srt.constrained import RegexFSM, TransformerTokenizer
1
+ """Cache for the compressed finite state machine."""
2
+
3
+ from sglang.srt.constrained import RegexGuide, TransformerTokenizer
2
4
  from sglang.srt.constrained.base_cache import BaseCache
3
5
 
4
6
 
@@ -6,7 +8,12 @@ class FSMCache(BaseCache):
6
8
  def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
7
9
  super().__init__(enable=enable)
8
10
 
11
+ if tokenizer_path.endswith(".json") or tokenizer_path.endswith(".model"):
12
+ # Do not support TiktokenTokenizer or SentencePieceTokenizer
13
+ return
14
+
9
15
  from importlib.metadata import version
16
+
10
17
  if version("outlines") >= "0.0.35":
11
18
  from transformers import AutoTokenizer
12
19
 
@@ -21,4 +28,4 @@ class FSMCache(BaseCache):
21
28
  )
22
29
 
23
30
  def init_value(self, regex):
24
- return RegexFSM(regex, self.outlines_tokenizer)
31
+ return RegexGuide(regex, self.outlines_tokenizer)
sglang/srt/constrained/jump_forward.py CHANGED
@@ -1,16 +1,43 @@
1
+ """
2
+ Faster constrained decoding.
3
+ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
4
+ """
5
+
6
+ import dataclasses
7
+ from collections import defaultdict
8
+
1
9
  import interegular
2
- from sglang.srt.constrained import FSMInfo, disk_cache, make_deterministic_fsm
10
+ import outlines.caching
11
+
12
+ from sglang.srt.constrained import (
13
+ FSMInfo,
14
+ disk_cache,
15
+ make_byte_level_fsm,
16
+ make_deterministic_fsm,
17
+ )
3
18
  from sglang.srt.constrained.base_cache import BaseCache
4
19
 
5
20
  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
6
21
 
7
22
 
23
+ @dataclasses.dataclass
24
+ class JumpEdge:
25
+ symbol: str = None
26
+ symbol_next_state: int = None
27
+ byte: int = None
28
+ byte_next_state: int = None
29
+
30
+
8
31
  class JumpForwardMap:
9
32
  def __init__(self, regex_string):
10
33
  @disk_cache()
11
34
  def _init_state_to_jump_forward(regex_string):
12
35
  regex_pattern = interegular.parse_pattern(regex_string)
13
- regex_fsm, _ = make_deterministic_fsm(regex_pattern.to_fsm().reduce())
36
+
37
+ byte_fsm = make_byte_level_fsm(
38
+ regex_pattern.to_fsm().reduce(), keep_utf8=True
39
+ )
40
+ regex_fsm, _ = make_deterministic_fsm(byte_fsm)
14
41
 
15
42
  fsm_info: FSMInfo = regex_fsm.fsm_info
16
43
 
@@ -20,40 +47,93 @@ class JumpForwardMap:
20
47
  id_to_symbol.setdefault(id_, []).append(symbol)
21
48
 
22
49
  transitions = fsm_info.transitions
23
- dirty_states = set()
50
+ outgoings_ct = defaultdict(int)
24
51
  state_to_jump_forward = {}
25
52
 
26
53
  for (state, id_), next_state in transitions.items():
27
- if state in dirty_states:
54
+ if id_ == fsm_info.alphabet_anything_value:
28
55
  continue
29
- if state in state_to_jump_forward:
30
- dirty_states.add(state)
31
- del state_to_jump_forward[state]
32
- continue
33
- if len(id_to_symbol[id_]) > 1:
34
- dirty_states.add(state)
56
+ symbols = id_to_symbol[id_]
57
+ for c in symbols:
58
+ if len(c) > 1:
59
+ # Skip byte level transitions
60
+ continue
61
+
62
+ outgoings_ct[state] += 1
63
+ if outgoings_ct[state] > 1:
64
+ if state in state_to_jump_forward:
65
+ del state_to_jump_forward[state]
66
+ break
67
+
68
+ state_to_jump_forward[state] = JumpEdge(
69
+ symbol=c,
70
+ symbol_next_state=next_state,
71
+ )
72
+
73
+ # Process the byte level jump forward
74
+ outgoings_ct = defaultdict(int)
75
+ for (state, id_), next_state in transitions.items():
76
+ if id_ == fsm_info.alphabet_anything_value:
35
77
  continue
36
-
37
- state_to_jump_forward[state] = (id_to_symbol[id_][0], next_state)
78
+ symbols = id_to_symbol[id_]
79
+ for c in symbols:
80
+ byte_ = None
81
+ if len(c) == 1 and ord(c) < 0x80:
82
+ # ASCII character
83
+ byte_ = ord(c)
84
+ elif len(c) > 1:
85
+ # FIXME: This logic is due to the leading \x00
86
+ # https://github.com/outlines-dev/outlines/pull/930
87
+ byte_ = int(symbols[0][1:], 16)
88
+
89
+ if byte_ is not None:
90
+ outgoings_ct[state] += 1
91
+ if outgoings_ct[state] > 1:
92
+ if state in state_to_jump_forward:
93
+ del state_to_jump_forward[state]
94
+ break
95
+ e = state_to_jump_forward.get(state, JumpEdge())
96
+ e.byte = byte_
97
+ e.byte_next_state = next_state
98
+ state_to_jump_forward[state] = e
38
99
 
39
100
  return state_to_jump_forward
40
101
 
41
102
  self.state_to_jump_forward = _init_state_to_jump_forward(regex_string)
42
103
 
43
- def valid_states(self):
44
- return self.state_to_jump_forward.keys()
104
+ def jump_forward_symbol(self, state):
105
+ jump_forward_str = ""
106
+ next_state = state
107
+ while state in self.state_to_jump_forward:
108
+ e = self.state_to_jump_forward[state]
109
+ if e.symbol is None:
110
+ break
111
+ jump_forward_str += e.symbol
112
+ next_state = e.symbol_next_state
113
+ state = next_state
114
+
115
+ return jump_forward_str, next_state
45
116
 
46
- def jump_forward(self, state):
117
+ def jump_forward_byte(self, state):
47
118
  if state not in self.state_to_jump_forward:
48
119
  return None
49
120
 
50
- jump_forward_str = ""
121
+ jump_forward_bytes = []
51
122
  next_state = None
52
123
  while state in self.state_to_jump_forward:
53
- symbol, next_state = self.state_to_jump_forward[state]
54
- jump_forward_str += symbol
124
+ e = self.state_to_jump_forward[state]
125
+ assert e.byte is not None and e.byte_next_state is not None
126
+ jump_forward_bytes.append((e.byte, e.byte_next_state))
127
+ next_state = e.byte_next_state
55
128
  state = next_state
56
- return jump_forward_str, next_state
129
+
130
+ return jump_forward_bytes
131
+
132
+ def is_jump_forward_symbol_state(self, state):
133
+ return (
134
+ state in self.state_to_jump_forward
135
+ and self.state_to_jump_forward[state].symbol is not None
136
+ )
57
137
 
58
138
 
59
139
  class JumpForwardCache(BaseCache):
@@ -64,12 +144,21 @@ class JumpForwardCache(BaseCache):
64
144
  return JumpForwardMap(regex)
65
145
 
66
146
 
67
- def test_main():
68
- regex_string = r"The google's DNS sever address is " + IP_REGEX
147
+ def test_main(regex_string):
69
148
  jump_forward_map = JumpForwardMap(regex_string)
70
- for state in jump_forward_map.valid_states():
71
- print(state, f'"{jump_forward_map.jump_forward(state)}"')
149
+ for state, e in jump_forward_map.state_to_jump_forward.items():
150
+ if e.symbol is not None:
151
+ jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
152
+ print(f"{state} -> {next_state}", jump_forward_str)
153
+ bytes_ = jump_forward_map.jump_forward_byte(state)
154
+ print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
72
155
 
73
156
 
74
157
  if __name__ == "__main__":
75
- test_main()
158
+ import outlines
159
+
160
+ outlines.caching.clear_cache()
161
+ test_main(r"The google's DNS sever address is " + IP_REGEX)
162
+ test_main(r"霍格沃茨特快列车|霍比特人比尔博")
163
+ # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
164
+ # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
sglang/srt/conversation.py CHANGED
@@ -1,10 +1,12 @@
1
+ """Conversation templates."""
2
+
1
3
  # Adapted from
2
4
  # https://github.com/lm-sys/FastChat/blob/main/fastchat/conversation.py
3
5
  import dataclasses
4
6
  from enum import IntEnum, auto
5
7
  from typing import Dict, List, Optional, Tuple, Union
6
8
 
7
- from sglang.srt.managers.openai_protocol import ChatCompletionRequest
9
+ from sglang.srt.openai_protocol import ChatCompletionRequest
8
10
 
9
11
 
10
12
  class SeparatorStyle(IntEnum):
@@ -400,7 +402,7 @@ register_conv_template(
400
402
  Conversation(
401
403
  name="chatml",
402
404
  system_template="<|im_start|>system\n{system_message}",
403
- system_message="You are an AI assistant.",
405
+ system_message="You are a helpful assistant.",
404
406
  roles=("<|im_start|>user", "<|im_start|>assistant"),
405
407
  sep_style=SeparatorStyle.CHATML,
406
408
  sep="<|im_end|>",
sglang/srt/flush_cache.py ADDED
@@ -0,0 +1,18 @@
1
+ """
2
+ Flush the KV cache.
3
+
4
+ Usage:
5
+ python3 -m sglang.srt.flush_cache --url http://localhost:30000
6
+ """
7
+
8
+ import argparse
9
+
10
+ import requests
11
+
12
+ if __name__ == "__main__":
13
+ parser = argparse.ArgumentParser()
14
+ parser.add_argument("--url", type=str, default="http://localhost:30000")
15
+ args = parser.parse_args()
16
+
17
+ response = requests.get(args.url + "/flush_cache")
18
+ assert response.status_code == 200