sglang 0.3.5__py3-none-any.whl → 0.3.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (50)
  1. sglang/bench_serving.py +113 -3
  2. sglang/srt/configs/model_config.py +5 -2
  3. sglang/srt/constrained/__init__.py +2 -66
  4. sglang/srt/constrained/base_grammar_backend.py +72 -0
  5. sglang/srt/constrained/outlines_backend.py +165 -0
  6. sglang/srt/constrained/outlines_jump_forward.py +182 -0
  7. sglang/srt/constrained/xgrammar_backend.py +114 -0
  8. sglang/srt/layers/attention/triton_ops/decode_attention.py +7 -0
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +6 -0
  10. sglang/srt/layers/fused_moe/fused_moe.py +23 -7
  11. sglang/srt/layers/quantization/base_config.py +4 -6
  12. sglang/srt/layers/vocab_parallel_embedding.py +216 -150
  13. sglang/srt/managers/io_struct.py +5 -3
  14. sglang/srt/managers/schedule_batch.py +14 -20
  15. sglang/srt/managers/scheduler.py +153 -94
  16. sglang/srt/managers/tokenizer_manager.py +81 -17
  17. sglang/srt/metrics/collector.py +211 -0
  18. sglang/srt/metrics/func_timer.py +108 -0
  19. sglang/srt/mm_utils.py +1 -1
  20. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  21. sglang/srt/model_executor/forward_batch_info.py +7 -3
  22. sglang/srt/model_executor/model_runner.py +2 -1
  23. sglang/srt/models/gemma2_reward.py +69 -0
  24. sglang/srt/models/gpt2.py +31 -37
  25. sglang/srt/models/internlm2_reward.py +62 -0
  26. sglang/srt/models/llama.py +11 -6
  27. sglang/srt/models/llama_reward.py +5 -26
  28. sglang/srt/models/qwen2_vl.py +5 -7
  29. sglang/srt/openai_api/adapter.py +6 -2
  30. sglang/srt/sampling/sampling_batch_info.py +2 -3
  31. sglang/srt/sampling/sampling_params.py +0 -14
  32. sglang/srt/server.py +58 -16
  33. sglang/srt/server_args.py +42 -22
  34. sglang/srt/utils.py +87 -0
  35. sglang/test/simple_eval_common.py +1 -1
  36. sglang/test/simple_eval_humaneval.py +2 -2
  37. sglang/test/simple_eval_mgsm.py +2 -2
  38. sglang/test/test_utils.py +18 -4
  39. sglang/utils.py +1 -0
  40. sglang/version.py +1 -1
  41. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/METADATA +11 -7
  42. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/RECORD +45 -42
  43. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/WHEEL +1 -1
  44. sglang/srt/constrained/base_tool_cache.py +0 -65
  45. sglang/srt/constrained/bnf_cache.py +0 -61
  46. sglang/srt/constrained/fsm_cache.py +0 -95
  47. sglang/srt/constrained/grammar.py +0 -190
  48. sglang/srt/constrained/jump_forward.py +0 -203
  49. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/LICENSE +0 -0
  50. {sglang-0.3.5.dist-info → sglang-0.3.5.post1.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py CHANGED
@@ -596,12 +596,20 @@ def sample_random_requests(
 
     # Filter out sequences that are too long or too short
     input_requests: List[Tuple[str, int, int]] = []
-    for i in range(num_prompts):
+    for data in dataset:
+        i = len(input_requests)
+        if i == num_prompts:
+            break
+
         # Tokenize the prompts and completions.
-        prompt = dataset[i][0]
+        prompt = data[0]
         prompt_token_ids = tokenizer.encode(prompt)
         prompt_len = len(prompt_token_ids)
 
+        # Skip empty prompt
+        if prompt_len == 0:
+            continue
+
         if prompt_len > input_lens[i]:
             input_ids = prompt_token_ids[: input_lens[i]]
         else:
@@ -627,6 +635,66 @@ def sample_random_requests(
     return input_requests
 
 
+def gen_prompt(tokenizer, token_num):
+    """Generate a random prompt of specified token length using tokenizer vocabulary."""
+    all_available_tokens = list(tokenizer.get_vocab().values())
+    selected_tokens = random.choices(all_available_tokens, k=token_num)
+    return tokenizer.decode(selected_tokens)
+
+
+def sample_generated_shared_prefix_requests(
+    num_groups: int,
+    prompts_per_group: int,
+    system_prompt_len: int,
+    question_len: int,
+    output_len: int,
+    tokenizer: PreTrainedTokenizerBase,
+) -> List[Tuple[str, int, int]]:
+    """Generate benchmark requests with shared system prompts using random tokens."""
+    # Generate system prompts for each group
+    system_prompts = []
+    for _ in range(num_groups):
+        system_prompt = gen_prompt(tokenizer, system_prompt_len)
+        system_prompts.append(system_prompt)
+
+    # Generate questions
+    questions = []
+    for _ in range(num_groups * prompts_per_group):
+        question = gen_prompt(tokenizer, question_len)
+        questions.append(question)
+
+    # Combine system prompts with questions
+    input_requests = []
+    total_input_tokens = 0
+    total_output_tokens = 0
+
+    for group_idx in range(num_groups):
+        system_prompt = system_prompts[group_idx]
+        for prompt_idx in range(prompts_per_group):
+            question = questions[group_idx * prompts_per_group + prompt_idx]
+            full_prompt = f"{system_prompt}\n\n{question}"
+            prompt_len = len(tokenizer.encode(full_prompt))
+
+            input_requests.append((full_prompt, prompt_len, output_len))
+            total_input_tokens += prompt_len
+            total_output_tokens += output_len
+
+    print(f"\nGenerated shared prefix dataset statistics:")
+    print(f"Number of groups: {num_groups}")
+    print(f"Prompts per group: {prompts_per_group}")
+    print(f"Total prompts: {len(input_requests)}")
+    print(f"Total input tokens: {total_input_tokens}")
+    print(f"Total output tokens: {total_output_tokens}")
+    print(
+        f"Average system prompt length: {sum(len(tokenizer.encode(sp)) for sp in system_prompts) / len(system_prompts):.1f} tokens"
+    )
+    print(
+        f"Average question length: {sum(len(tokenizer.encode(q)) for q in questions) / len(questions):.1f} tokens\n"
+    )
+
+    return input_requests
+
+
 async def get_request(
     input_requests: List[Tuple[str, int, int]],
     request_rate: float,
@@ -1048,6 +1116,15 @@ def run_benchmark(args_: argparse.Namespace):
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
         )
+    elif args.dataset_name == "generated-shared-prefix":
+        input_requests = sample_generated_shared_prefix_requests(
+            num_groups=args.gen_num_groups,
+            prompts_per_group=args.gen_prompts_per_group,
+            system_prompt_len=args.gen_system_prompt_len,
+            question_len=args.gen_question_len,
+            output_len=args.gen_output_len,
+            tokenizer=tokenizer,
+        )
     else:
         raise ValueError(f"Unknown dataset: {args.dataset_name}")
 
@@ -1121,7 +1198,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random"],
+        choices=["sharegpt", "random", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
@@ -1208,5 +1285,38 @@ if __name__ == "__main__":
         help="Append given JSON object to the request payload. You can use this to specify"
         "additional generate params like sampling params.",
     )
+
+    group = parser.add_argument_group("generated-shared-prefix dataset arguments")
+    group.add_argument(
+        "--gen-num-groups",
+        type=int,
+        default=64,
+        help="Number of system prompt groups for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-prompts-per-group",
+        type=int,
+        default=16,
+        help="Number of prompts per system prompt group for generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-system-prompt-len",
+        type=int,
+        default=2048,
+        help="Target length in tokens for system prompts in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-question-len",
+        type=int,
+        default=128,
+        help="Target length in tokens for questions in generated-shared-prefix dataset",
+    )
+    group.add_argument(
+        "--gen-output-len",
+        type=int,
+        default=256,
+        help="Target length in tokens for outputs in generated-shared-prefix dataset",
+    )
+
     args = parser.parse_args()
     run_benchmark(args)
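
The generated-shared-prefix dataset stresses prefix caching: every request in a group shares one long random system prompt. A hypothetical invocation with the new flags (assuming an SGLang server is already running; the values shown are the argparse defaults from this diff):

    python3 -m sglang.bench_serving --backend sglang \
        --dataset-name generated-shared-prefix \
        --gen-num-groups 64 --gen-prompts-per-group 16 \
        --gen-system-prompt-len 2048 --gen-question-len 128 \
        --gen-output-len 256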
sglang/srt/configs/model_config.py CHANGED
@@ -39,7 +39,7 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
-        is_embedding: Optional[bool] = None
+        is_embedding: Optional[bool] = None,
     ) -> None:
         # Parse args
         self.model_override_args = json.loads(model_override_args)
@@ -52,7 +52,9 @@ class ModelConfig:
         self.hf_text_config = get_hf_text_config(self.hf_config)
 
         # Check model type
-        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_generation = is_generation_model(
+            self.hf_config.architectures, is_embedding
+        )
         self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
 
@@ -208,6 +210,7 @@ def is_generation_model(model_architectures: List[str], is_embedding: bool = Fal
         or "MistralModel" in model_architectures
         or "LlamaForSequenceClassification" in model_architectures
         or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+        or "InternLM2ForRewardModel" in model_architectures
     ):
         return False
     else:
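
With "InternLM2ForRewardModel" added to the architecture list above, the classifier now routes that model to the non-generation (reward/embedding) path. A small sketch of the effect, using only the function shown in this diff:

    from sglang.srt.configs.model_config import is_generation_model

    print(is_generation_model(["InternLM2ForRewardModel"]))  # False: served as a reward model
    print(is_generation_model(["InternLM2ForCausalLM"]))     # True: ordinary generation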
sglang/srt/constrained/__init__.py CHANGED
@@ -13,69 +13,5 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
-"""For constrained decoding."""
-
-import json
-from typing import Dict, Optional, Union
-
-from pydantic import BaseModel
-
-try:
-    from outlines.caching import cache as disk_cache
-    from outlines.caching import disable_cache
-    from outlines.fsm.guide import RegexGuide
-    from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
-    from outlines.models.transformers import TransformerTokenizer
-except ImportError as e:
-    print(
-        f'\nError: {e}. Please install a new version of outlines by `pip install "outlines>=0.0.44"`\n'
-    )
-    raise
-
-try:
-    from outlines.fsm.json_schema import build_regex_from_object
-except ImportError:
-    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
-    # which only accepts string schema as input.
-    from outlines.fsm.json_schema import build_regex_from_schema
-
-    def build_regex_from_object(
-        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
-    ):
-        if isinstance(object, type(BaseModel)):
-            schema = json.dumps(object.model_json_schema())
-        elif isinstance(object, Dict):
-            schema = json.dumps(object)
-        else:
-            schema = object
-        return build_regex_from_schema(schema, whitespace_pattern)
-
-
-try:
-    from xgrammar import (
-        GrammarMatcher,
-        GrammarMatcherInitContext,
-        GrammarMatcherInitContextCache,
-    )
-except ImportError as e:
-
-    class Dummy:
-        pass
-
-    GrammarMatcher = Dummy
-    GrammarMatcherInitContext = Dummy
-    GrammarMatcherInitContextCache = Dummy
-
-__all__ = [
-    "RegexGuide",
-    "FSMInfo",
-    "make_deterministic_fsm",
-    "build_regex_from_object",
-    "TransformerTokenizer",
-    "disk_cache",
-    "disable_cache",
-    "make_byte_level_fsm",
-    "GrammarMatcher",
-    "GrammarMatcherInitContext",
-    "GrammarMatcherInitContextCache",
-]
+# TODO(lmzheng): make this an optional dependency
+from sglang.srt.constrained.outlines_backend import build_regex_from_object
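
After this refactor, the package root only re-exports build_regex_from_object from the new outlines backend. A minimal usage sketch (assumes outlines and pydantic are installed; the Answer model is a made-up example):

    from pydantic import BaseModel

    from sglang.srt.constrained import build_regex_from_object

    class Answer(BaseModel):  # hypothetical schema for illustration
        name: str
        age: int

    # Accepts a JSON-schema string, a dict, or a pydantic model class and
    # returns a regex whose matches are JSON documents valid under the schema.
    regex = build_regex_from_object(Answer)
    print(regex[:80])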
sglang/srt/constrained/base_grammar_backend.py ADDED
@@ -0,0 +1,72 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""The baseclass of backends for grammar-guided constrained decoding."""
+
+from concurrent.futures import Future, ThreadPoolExecutor
+from dataclasses import dataclass
+from threading import Event, Lock
+from typing import Any, Optional, Tuple
+
+
+@dataclass
+class CacheEntry:
+    value: Any
+    event: Event
+
+
+class BaseGrammarObject:
+    pass
+
+
+class BaseGrammarBackend:
+    def __init__(self):
+        self.executor = ThreadPoolExecutor()
+        self.cache = {}
+        self.cache_lock = Lock()
+
+    def init_value(self, key: Tuple[str, str]) -> BaseGrammarObject:
+        with self.cache_lock:
+            if key in self.cache:
+                cache_hit = True
+                entry = self.cache[key]
+            else:
+                cache_hit = False
+                entry = CacheEntry(None, Event())
+                self.cache[key] = entry
+
+        if cache_hit:
+            entry.event.wait()
+        else:
+            entry.value = self.init_value_impl(key)
+            entry.event.set()
+        return entry.value.copy()
+
+    def init_value_impl(self, key: Tuple[str, str]) -> BaseGrammarObject:
+        raise NotImplementedError()
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[BaseGrammarObject]:
+        with self.cache_lock:
+            entry = self.cache.get(key)
+            if not entry or not entry.event.is_set():
+                return None
+            return self.cache[key].value.copy()
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        return self.executor.submit(self.init_value, key)
+
+    def reset(self):
+        with self.cache_lock:
+            self.cache.clear()
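
For orientation, a minimal sketch (not part of the diff) of how a concrete backend plugs into this base class: a subclass only supplies init_value_impl, while compilation results are cached per (key_type, key_string) and can be awaited as futures. EchoGrammar and EchoBackend are made-up names:

    from sglang.srt.constrained.base_grammar_backend import (
        BaseGrammarBackend,
        BaseGrammarObject,
    )

    class EchoGrammar(BaseGrammarObject):  # hypothetical grammar object
        def __init__(self, key_string):
            self.key_string = key_string

        def copy(self):  # init_value() hands out copies, one per request
            return EchoGrammar(self.key_string)

    class EchoBackend(BaseGrammarBackend):  # hypothetical backend
        def init_value_impl(self, key):
            key_type, key_string = key  # e.g. ("regex", r"\d+")
            return EchoGrammar(key_string)  # expensive compilation goes here

    backend = EchoBackend()
    future = backend.get_future_value(("regex", r"\d+"))  # compiles in a worker thread
    grammar = future.result()  # blocks until ready; duplicate keys hit the cache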
sglang/srt/constrained/outlines_backend.py ADDED
@@ -0,0 +1,165 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Constrained decoding with outlines backend."""
+
+import json
+import logging
+from typing import Dict, List, Optional, Tuple, Union
+
+import torch
+from outlines.fsm.guide import RegexGuide
+from outlines.models.transformers import TransformerTokenizer
+
+from sglang.srt.constrained.base_grammar_backend import (
+    BaseGrammarBackend,
+    BaseGrammarObject,
+)
+from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap
+
+logger = logging.getLogger(__name__)
+
+
+try:
+    from outlines.fsm.json_schema import build_regex_from_object
+except ImportError:
+    # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
+    # which only accepts string schema as input.
+    from outlines.fsm.json_schema import build_regex_from_schema
+    from pydantic import BaseModel
+
+    def build_regex_from_object(
+        object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+    ):
+        if isinstance(object, type(BaseModel)):
+            schema = json.dumps(object.model_json_schema())
+        elif isinstance(object, Dict):
+            schema = json.dumps(object)
+        else:
+            schema = object
+        return build_regex_from_schema(schema, whitespace_pattern)
+
+
+class OutlinesGrammar(BaseGrammarObject):
+    def __init__(
+        self,
+        guide: RegexGuide,
+        jump_forward_map: Union[OutlinesJumpForwardMap, None],
+    ) -> None:
+        self.guide = guide
+        self.jump_forward_map = jump_forward_map
+        self.state = 0
+
+    def accept_token(self, token: int):
+        self.state = self.guide.get_next_state(self.state, token)
+
+    def try_jump_forward(self, tokenizer) -> Optional[Tuple]:
+        if not self.jump_forward_map:
+            return None
+
+        jump_forward_bytes = self.jump_forward_map.jump_forward_byte(self.state)
+        if jump_forward_bytes is None or len(jump_forward_bytes) <= 1:
+            return None
+
+        # preprocess the jump forward string
+        suffix_bytes = []
+        continuation_range = range(0x80, 0xC0)
+        cur_state = self.state
+        while (
+            len(jump_forward_bytes) and jump_forward_bytes[0][0] in continuation_range
+        ):
+            # continuation bytes
+            byte_edge = jump_forward_bytes.pop(0)
+            suffix_bytes.append(byte_edge[0])
+            cur_state = byte_edge[1]
+
+        suffix_tokens = [f"<0x{hex(b)[2:].upper()}>" for b in suffix_bytes]
+        suffix_ids = tokenizer.convert_tokens_to_ids(suffix_tokens)
+        return suffix_ids, cur_state
+
+    def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
+        _, cur_state = helper
+        return self.jump_forward_map.jump_forward_symbol(cur_state)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        self.state = next_state
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor):
+        vocab_mask.fill_(1)
+        vocab_mask[self.guide.get_next_instruction(self.state).tokens] = 0
+
+    def copy(self):
+        return OutlinesGrammar(self.guide, self.jump_forward_map)
+
+
+class OutlinesGrammarBackend(BaseGrammarBackend):
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: bool,
+        allow_jump_forward: bool,
+    ):
+        super().__init__()
+
+        try:
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+        except AttributeError:
+            # FIXME: tmp fix for chatglm2 & chatglm3 (pad_token_id=0)
+            origin_pad_token_id = tokenizer.pad_token_id
+
+            def fset(self, value):
+                self._value = value
+
+            type(tokenizer).pad_token_id = property(
+                fget=type(tokenizer).pad_token_id.fget, fset=fset
+            )
+            self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+            self.outlines_tokenizer.tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token_id = origin_pad_token_id
+            self.outlines_tokenizer.pad_token = (
+                self.outlines_tokenizer.tokenizer.pad_token
+            )
+            self.outlines_tokenizer.vocabulary = (
+                self.outlines_tokenizer.tokenizer.get_vocab()
+            )
+        self.allow_jump_forward = allow_jump_forward
+        self.whitespace_pattern = whitespace_pattern
+
+    def init_value_impl(self, key: Tuple[str, str]) -> OutlinesGrammar:
+        key_type, key_string = key
+        if key_type == "json":
+            try:
+                regex = build_regex_from_object(
+                    key_string,
+                    whitespace_pattern=self.whitespace_pattern,
+                )
+            except NotImplementedError as e:
+                logger.warning(
+                    f"skip invalid json schema: json_schema={key_string}, {e=}"
+                )
+                return None, key_string
+        elif key_type == "regex":
+            regex = key_string
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+        guide = RegexGuide(regex, self.outlines_tokenizer)
+        if self.allow_jump_forward:
+            jump_forward_map = OutlinesJumpForwardMap(regex)
+        else:
+            jump_forward_map = None
+        return OutlinesGrammar(guide, jump_forward_map)
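
A sketch (not in the diff) of how a caller might drive an OutlinesGrammar during sampling; the "gpt2" tokenizer and the boolean mask convention are illustrative assumptions, not the actual scheduler integration:

    import torch
    from transformers import AutoTokenizer

    from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer
    backend = OutlinesGrammarBackend(
        tokenizer, whitespace_pattern=None, allow_jump_forward=False
    )
    grammar = backend.init_value(("regex", r"(yes|no)"))

    logits = torch.randn(len(tokenizer))
    vocab_mask = torch.empty(len(tokenizer), dtype=torch.bool)
    grammar.fill_vocab_mask(vocab_mask)            # True = forbidden, False = allowed
    logits.masked_fill_(vocab_mask, float("-inf"))

    next_token = int(torch.argmax(logits))         # any sampling rule works here
    grammar.accept_token(next_token)               # advance the guide's FSM state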
sglang/srt/constrained/outlines_jump_forward.py ADDED
@@ -0,0 +1,182 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+Faster constrained decoding with jump forward decoding / compressed finite state machine.
+Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
+"""
+
+import dataclasses
+import logging
+from collections import defaultdict
+
+import interegular
+from interegular import InvalidSyntax
+from outlines.caching import cache as disk_cache
+from outlines.fsm.regex import FSMInfo, make_byte_level_fsm, make_deterministic_fsm
+
+IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
+
+logger = logging.getLogger(__name__)
+
+
+@dataclasses.dataclass
+class JumpEdge:
+    symbol: str = None
+    symbol_next_state: int = None
+    byte: int = None
+    byte_next_state: int = None
+
+
+@disk_cache()
+def init_state_to_jump_forward(regex_string):
+    try:
+        regex_pattern = interegular.parse_pattern(regex_string)
+    except InvalidSyntax as e:
+        logger.warning(f"skip invalid regex: {regex_string}, {e=}")
+        return
+
+    byte_fsm = make_byte_level_fsm(regex_pattern.to_fsm().reduce(), keep_utf8=True)
+    regex_fsm, _ = make_deterministic_fsm(byte_fsm)
+
+    fsm_info: FSMInfo = regex_fsm.fsm_info
+
+    symbol_to_id = fsm_info.alphabet_symbol_mapping
+    id_to_symbol = {}
+    for symbol, id_ in symbol_to_id.items():
+        id_to_symbol.setdefault(id_, []).append(symbol)
+
+    transitions = fsm_info.transitions
+
+    outgoings_ct = defaultdict(int)
+    # NOTE(lsyin): Final states can lead to terminate, so they have one outgoing edge naturally
+    for s in fsm_info.finals:
+        outgoings_ct[s] = 1
+
+    state_to_jump_forward = {}
+    for (state, id_), next_state in transitions.items():
+        if id_ == fsm_info.alphabet_anything_value:
+            # Arbitrarily symbol cannot be recognized as jump forward
+            continue
+
+        symbols = id_to_symbol[id_]
+        for c in symbols:
+            if len(c) > 1:
+                # Skip byte level transitions like c = "5E"
+                continue
+
+            outgoings_ct[state] += 1
+            if outgoings_ct[state] > 1:
+                if state in state_to_jump_forward:
+                    del state_to_jump_forward[state]
+                break
+
+            state_to_jump_forward[state] = JumpEdge(
+                symbol=c,
+                symbol_next_state=next_state,
+            )
+
+    # Process the byte level jump forward
+    outgoings_ct = defaultdict(int)
+    for s in fsm_info.finals:
+        outgoings_ct[s] = 1
+
+    for (state, id_), next_state in transitions.items():
+        if id_ == fsm_info.alphabet_anything_value:
+            continue
+        symbols = id_to_symbol[id_]
+        for c in symbols:
+            byte_ = None
+            if len(c) == 1 and ord(c) < 0x80:
+                # ASCII character
+                byte_ = ord(c)
+            elif len(c) > 1:
+                # FIXME: This logic is due to the leading \x00
+                # https://github.com/outlines-dev/outlines/pull/930
+                byte_ = int(symbols[0][1:], 16)
+
+            if byte_ is not None:
+                outgoings_ct[state] += 1
+                if outgoings_ct[state] > 1:
+                    if state in state_to_jump_forward:
+                        del state_to_jump_forward[state]
+                    break
+                e = state_to_jump_forward.get(state, JumpEdge())
+                e.byte = byte_
+                e.byte_next_state = next_state
+                state_to_jump_forward[state] = e
+
+    return state_to_jump_forward
+
+
+class OutlinesJumpForwardMap:
+    def __init__(self, regex_string):
+        self.state_to_jump_forward = init_state_to_jump_forward(regex_string)
+
+    def jump_forward_symbol(self, state):
+        jump_forward_str = ""
+        next_state = state
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            if e.symbol is None:
+                break
+            jump_forward_str += e.symbol
+            next_state = e.symbol_next_state
+            state = next_state
+
+        return jump_forward_str, next_state
+
+    def jump_forward_byte(self, state):
+        if state not in self.state_to_jump_forward:
+            return None
+
+        jump_forward_bytes = []
+        next_state = None
+        while state in self.state_to_jump_forward:
+            e = self.state_to_jump_forward[state]
+            assert e.byte is not None and e.byte_next_state is not None
+            jump_forward_bytes.append((e.byte, e.byte_next_state))
+            next_state = e.byte_next_state
+            state = next_state
+
+        return jump_forward_bytes
+
+    def is_jump_forward_symbol_state(self, state):
+        return (
+            state in self.state_to_jump_forward
+            and self.state_to_jump_forward[state].symbol is not None
+        )
+
+
+def test_main(regex_string):
+    jump_forward_map = OutlinesJumpForwardMap(regex_string)
+    for state, e in jump_forward_map.state_to_jump_forward.items():
+        if e.symbol is not None:
+            jump_forward_str, next_state = jump_forward_map.jump_forward_symbol(state)
+            print(f"{state} -> {next_state}", jump_forward_str)
+        bytes_ = jump_forward_map.jump_forward_byte(state)
+        print(f"{state} -> {bytes_[-1][1]}", [hex(b) for b, _ in bytes_])
+
+
+if __name__ == "__main__":
+    import outlines
+
+    outlines.caching.clear_cache()
+    test_main(r"The google's DNS sever address is " + IP_REGEX)
+    test_main(r"霍格沃茨特快列车|霍比特人比尔博")
+    # 霍格: \xe9\x9c\x8d \xe6\xa0\xbc ...
+    # 霍比: \xe9\x9c\x8d \xe6\xaf\x94 ...
+
+    test_main(r"[-+]?[0-9]+[ ]*")
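
To make the jump-forward map concrete, a small sketch beyond the built-in test (state 0 is the FSM start state, consistent with OutlinesGrammar initializing self.state = 0; the regex is a made-up example whose JSON scaffolding is deterministic):

    from sglang.srt.constrained.outlines_jump_forward import OutlinesJumpForwardMap

    # Every state along the fixed prefix has exactly one outgoing symbol edge,
    # so the whole prefix can be emitted without any model forward passes.
    jump_map = OutlinesJumpForwardMap(r'\{"name": "[a-z]+"\}')
    prefix, next_state = jump_map.jump_forward_symbol(0)
    print(repr(prefix), next_state)  # expected prefix: '{"name": "'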