sglang 0.1.12__py3-none-any.whl → 0.1.14__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (40)
  1. sglang/__init__.py +1 -1
  2. sglang/api.py +14 -0
  3. sglang/backend/anthropic.py +18 -12
  4. sglang/backend/base_backend.py +6 -0
  5. sglang/backend/openai.py +41 -12
  6. sglang/backend/runtime_endpoint.py +57 -6
  7. sglang/lang/chat_template.py +47 -26
  8. sglang/lang/interpreter.py +15 -2
  9. sglang/lang/ir.py +1 -1
  10. sglang/srt/constrained/__init__.py +23 -1
  11. sglang/srt/constrained/fsm_cache.py +14 -3
  12. sglang/srt/layers/context_flashattention_nopad.py +1 -1
  13. sglang/srt/layers/extend_attention.py +7 -6
  14. sglang/srt/layers/radix_attention.py +2 -10
  15. sglang/srt/layers/token_attention.py +12 -4
  16. sglang/srt/managers/io_struct.py +3 -1
  17. sglang/srt/managers/router/infer_batch.py +6 -2
  18. sglang/srt/managers/router/model_rpc.py +45 -32
  19. sglang/srt/managers/router/model_runner.py +40 -25
  20. sglang/srt/managers/tokenizer_manager.py +2 -0
  21. sglang/srt/model_config.py +12 -5
  22. sglang/srt/models/gemma.py +340 -0
  23. sglang/srt/models/llama2.py +5 -5
  24. sglang/srt/models/llava.py +2 -4
  25. sglang/srt/models/mixtral.py +5 -5
  26. sglang/srt/models/qwen.py +4 -4
  27. sglang/srt/models/qwen2.py +5 -5
  28. sglang/srt/models/stablelm.py +293 -0
  29. sglang/srt/server.py +111 -47
  30. sglang/srt/server_args.py +44 -9
  31. sglang/srt/utils.py +1 -0
  32. sglang/test/test_utils.py +1 -1
  33. sglang/utils.py +15 -12
  34. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/METADATA +16 -6
  35. sglang-0.1.14.dist-info/RECORD +64 -0
  36. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/WHEEL +1 -1
  37. sglang/srt/models/gpt_neox.py +0 -274
  38. sglang-0.1.12.dist-info/RECORD +0 -63
  39. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/LICENSE +0 -0
  40. {sglang-0.1.12.dist-info → sglang-0.1.14.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -1,4 +1,4 @@
- __version__ = "0.1.12"
+ __version__ = "0.1.14"
 
  from sglang.api import *
  from sglang.global_config import global_config
sglang/api.py CHANGED
@@ -44,6 +44,20 @@ def set_default_backend(backend: BaseBackend):
      global_config.default_backend = backend
 
 
+ def flush_cache(backend: BaseBackend = None):
+     backend = backend or global_config.default_backend
+     if backend is None:
+         return False
+     return backend.flush_cache()
+
+
+ def get_server_args(backend: BaseBackend = None):
+     backend = backend or global_config.default_backend
+     if backend is None:
+         return None
+     return backend.get_server_args()
+
+
  def gen(
      name: Optional[str] = None,
      max_tokens: Optional[int] = None,
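Note: a minimal usage sketch of the two helpers added above. It assumes they are re-exported from the top-level package along with the rest of sglang.api; the endpoint URL is a placeholder.

    import sglang as sgl

    # Point the default backend at a running SRT server (URL is illustrative).
    sgl.set_default_backend(sgl.RuntimeEndpoint("http://localhost:30000"))

    print(sgl.get_server_args())   # forwards to backend.get_server_args()
    ok = sgl.flush_cache()         # forwards to backend.flush_cache(); False if no backend is set
    print("cache flushed:", ok)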
sglang/backend/anthropic.py CHANGED
@@ -30,13 +30,17 @@ class Anthropic(BaseBackend):
          s: StreamExecutor,
          sampling_params: SglSamplingParams,
      ):
-         prompt = s.text_
-         ret = anthropic.Anthropic().completions.create(
+         if s.messages_:
+             messages = s.messages_
+         else:
+             messages = [{"role": "user", "content": s.text_}]
+
+         ret = anthropic.Anthropic().messages.create(
              model=self.model_name,
-             prompt=prompt,
+             messages=messages,
              **sampling_params.to_anthropic_kwargs(),
          )
-         comp = ret.completion
+         comp = ret.content[0].text
 
          return comp, {}
 
@@ -45,13 +49,15 @@ class Anthropic(BaseBackend):
          s: StreamExecutor,
          sampling_params: SglSamplingParams,
      ):
-         prompt = s.text_
-         generator = anthropic.Anthropic().completions.create(
+         if s.messages_:
+             messages = s.messages_
+         else:
+             messages = [{"role": "user", "content": s.text_}]
+
+         with anthropic.Anthropic().messages.stream(
              model=self.model_name,
-             prompt=prompt,
-             stream=True,
+             messages=messages,
              **sampling_params.to_anthropic_kwargs(),
-         )
-
-         for ret in generator:
-             yield ret.completion, {}
+         ) as stream:
+             for text in stream.text_stream:
+                 yield text, {}
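Note: the backend now calls Anthropic's Messages API (messages.create / messages.stream) instead of the legacy Completions API, and passes s.messages_ through when the program uses chat roles. A hedged end-to-end sketch; the model name and the top-level sgl.Anthropic export are assumptions, and ANTHROPIC_API_KEY must be set in the environment.

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        s += sgl.user(question)                              # becomes a "user" message
        s += sgl.assistant(sgl.gen("answer", max_tokens=64))

    sgl.set_default_backend(sgl.Anthropic("claude-3-haiku-20240307"))  # model name illustrative
    state = qa.run(question="What is the capital of France?")
    print(state["answer"])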
sglang/backend/base_backend.py CHANGED
@@ -72,3 +72,9 @@ class BaseBackend:
 
      def shutdown(self):
          pass
+
+     def flush_cache(self):
+         pass
+
+     def get_server_args(self):
+         pass
sglang/backend/openai.py CHANGED
@@ -4,7 +4,7 @@ from typing import Callable, List, Optional, Union
 
  import numpy as np
  from sglang.backend.base_backend import BaseBackend
- from sglang.lang.chat_template import get_chat_template
+ from sglang.lang.chat_template import ChatTemplate, get_chat_template_by_model_path
  from sglang.lang.interpreter import StreamExecutor
  from sglang.lang.ir import SglSamplingParams
 
@@ -41,23 +41,45 @@ INSTRUCT_MODEL_NAMES = [
 
 
  class OpenAI(BaseBackend):
-     def __init__(self, model_name, *args, **kwargs):
+     def __init__(
+         self,
+         model_name: str,
+         is_chat_model: Optional[bool] = None,
+         chat_template: Optional[ChatTemplate] = None,
+         is_azure: bool = False,
+         *args,
+         **kwargs,
+     ):
          super().__init__()
 
          if isinstance(openai, Exception):
              raise openai
 
-         self.client = openai.OpenAI(*args, **kwargs)
+         if is_azure:
+             self.client = openai.AzureOpenAI(*args, **kwargs)
+         else:
+             self.client = openai.OpenAI(*args, **kwargs)
+
          self.model_name = model_name
-         self.tokenizer = tiktoken.encoding_for_model(model_name)
+         try:
+             self.tokenizer = tiktoken.encoding_for_model(model_name)
+         except KeyError:
+             self.tokenizer = tiktoken.get_encoding("cl100k_base")
          self.logit_bias_int = create_logit_bias_int(self.tokenizer)
 
-         if model_name in INSTRUCT_MODEL_NAMES:
-             self.is_chat_model = False
+         self.chat_template = chat_template or get_chat_template_by_model_path(
+             model_name
+         )
+
+         if is_chat_model is not None:
+             self.is_chat_model = is_chat_model
          else:
-             self.is_chat_model = True
+             if model_name in INSTRUCT_MODEL_NAMES:
+                 self.is_chat_model = False
+             else:
+                 self.is_chat_model = True
 
-         self.chat_template = get_chat_template("default")
+         self.chat_begin_str = self.chat_template.role_prefix_and_suffix["assistant"][0]
 
      def get_chat_template(self):
          return self.chat_template
@@ -69,7 +91,7 @@ class OpenAI(BaseBackend):
      ):
          if sampling_params.dtype is None:
              if self.is_chat_model:
-                 if not s.text_.endswith("ASSISTANT:"):
+                 if not s.text_.endswith(self.chat_begin_str):
                      raise RuntimeError(
                          "This use case is not supported. "
                          "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
@@ -122,7 +144,11 @@ class OpenAI(BaseBackend):
      ):
          if sampling_params.dtype is None:
              if self.is_chat_model:
-                 assert s.text_.endswith("ASSISTANT:")
+                 if not s.text_.endswith(self.chat_begin_str):
+                     raise RuntimeError(
+                         "This use case is not supported. "
+                         "For OpenAI chat models, sgl.gen must be right after sgl.assistant"
+                     )
                  prompt = s.messages_
              else:
                  prompt = s.text_
@@ -137,7 +163,7 @@ class OpenAI(BaseBackend):
              )
              return generator
          else:
-             raise ValueError(f"Unknown dtype: {dtype}")
+             raise ValueError(f"Unknown dtype: {sampling_params.dtype}")
 
      def select(
          self,
@@ -241,7 +267,10 @@ def openai_completion_stream(client, retries=3, is_chat=None, prompt=None, **kwargs):
                  messages=prompt, stream=True, **kwargs
              )
              for ret in generator:
-                 content = ret.choices[0].delta.content
+                 try:
+                     content = ret.choices[0].delta.content
+                 except IndexError:
+                     content = None
                  yield content or "", {}
          else:
              generator = client.completions.create(
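Note: a hedged construction sketch for the widened constructor shown above. The deployment name, endpoint, and api_version are placeholders; extra keyword arguments are forwarded unchanged to openai.AzureOpenAI (or openai.OpenAI when is_azure is False).

    import sglang as sgl
    from sglang.lang.chat_template import get_chat_template

    backend = sgl.OpenAI(
        "my-gpt-35-turbo-deployment",                    # placeholder Azure deployment name
        is_chat_model=True,                              # override the INSTRUCT_MODEL_NAMES heuristic
        chat_template=get_chat_template("chatml"),       # skip get_chat_template_by_model_path
        is_azure=True,
        azure_endpoint="https://example.openai.azure.com",  # forwarded to openai.AzureOpenAI
        api_version="2024-02-01",
        api_key="...",
    )
    sgl.set_default_backend(backend)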
sglang/backend/runtime_endpoint.py CHANGED
@@ -12,15 +12,26 @@ from sglang.utils import encode_image_base64, find_printable_text, http_request
 
 
  class RuntimeEndpoint(BaseBackend):
-     def __init__(self, base_url, auth_token=None):
+     def __init__(
+         self,
+         base_url: str,
+         auth_token: Optional[str] = None,
+         api_key: Optional[str] = None,
+         verify: Optional[str] = None,
+     ):
          super().__init__()
          self.support_concate_and_append = True
 
          self.base_url = base_url
          self.auth_token = auth_token
+         self.api_key = api_key
+         self.verify = verify
 
          res = http_request(
-             self.base_url + "/get_model_info", auth_token=self.auth_token
+             self.base_url + "/get_model_info",
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
          self.model_info = res.json()
@@ -32,6 +43,22 @@ class RuntimeEndpoint(BaseBackend):
      def get_model_name(self):
          return self.model_info["model_path"]
 
+     def flush_cache(self):
+         res = http_request(
+             self.base_url + "/flush_cache",
+             auth_token=self.auth_token,
+             verify=self.verify,
+         )
+         return res.status_code == 200
+
+     def get_server_args(self):
+         res = http_request(
+             self.base_url + "/get_server_args",
+             auth_token=self.auth_token,
+             verify=self.verify,
+         )
+         return res.json()
+
      def get_chat_template(self):
          return self.chat_template
 
@@ -40,6 +67,8 @@ class RuntimeEndpoint(BaseBackend):
              self.base_url + "/generate",
              json={"text": prefix_str, "sampling_params": {"max_new_tokens": 0}},
              auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
 
@@ -48,6 +77,8 @@ class RuntimeEndpoint(BaseBackend):
              self.base_url + "/generate",
              json={"text": s.text_, "sampling_params": {"max_new_tokens": 0}},
              auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
 
@@ -55,7 +86,11 @@ class RuntimeEndpoint(BaseBackend):
          data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
          self._add_images(s, data)
          res = http_request(
-             self.base_url + "/generate", json=data, auth_token=self.auth_token
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
 
@@ -87,7 +122,11 @@ class RuntimeEndpoint(BaseBackend):
          self._add_images(s, data)
 
          res = http_request(
-             self.base_url + "/generate", json=data, auth_token=self.auth_token
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          obj = res.json()
          comp = obj["text"]
@@ -126,6 +165,8 @@ class RuntimeEndpoint(BaseBackend):
              json=data,
              stream=True,
              auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          pos = 0
 
@@ -157,7 +198,11 @@ class RuntimeEndpoint(BaseBackend):
          data = {"text": s.text_, "sampling_params": {"max_new_tokens": 0}}
          self._add_images(s, data)
          res = http_request(
-             self.base_url + "/generate", json=data, auth_token=self.auth_token
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
          prompt_len = res.json()["meta_info"]["prompt_tokens"]
@@ -171,7 +216,11 @@ class RuntimeEndpoint(BaseBackend):
          }
          self._add_images(s, data)
          res = http_request(
-             self.base_url + "/generate", json=data, auth_token=self.auth_token
+             self.base_url + "/generate",
+             json=data,
+             auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
          obj = res.json()
@@ -188,6 +237,8 @@ class RuntimeEndpoint(BaseBackend):
              self.base_url + "/concate_and_append_request",
              json={"src_rids": src_rids, "dst_rid": dst_rid},
              auth_token=self.auth_token,
+             api_key=self.api_key,
+             verify=self.verify,
          )
          assert res.status_code == 200
 
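Note: a sketch of the new constructor parameters. api_key and verify are forwarded to every http_request call, so verify presumably accepts whatever the underlying HTTP layer takes for TLS verification (for example a CA-bundle path); all values below are placeholders.

    from sglang.backend.runtime_endpoint import RuntimeEndpoint

    backend = RuntimeEndpoint(
        "https://my-sglang-host:30000",     # placeholder URL
        api_key="secret-key",               # sent with every request alongside auth_token
        verify="/etc/ssl/certs/my-ca.pem",  # passed through to http_request
    )
    print(backend.get_server_args())        # new helper, hits /get_server_args
    backend.flush_cache()                   # new helper, hits /flush_cache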
sglang/lang/chat_template.py CHANGED
@@ -12,42 +12,43 @@ class ChatTemplateStyle(Enum):
  class ChatTemplate:
      name: str
      default_system_prompt: str
-     role_prefix_and_suffix: Dict[str, Tuple[str]]
+     role_prefix_and_suffix: Dict[str, Tuple[str, str]]
      stop_str: List[str] = ()
      image_token: str = "<image>"
      style: ChatTemplateStyle = ChatTemplateStyle.PLAIN
 
-     def get_prefix_and_suffix(self, role, hist_messages):
-         if self.style == ChatTemplateStyle.PLAIN:
-             return self.role_prefix_and_suffix[role]
-         elif self.style == ChatTemplateStyle.LLAMA2:
-             if len(hist_messages) == 0 and role == "system":
-                 return (
-                     self.role_prefix_and_suffix["user"][0]
-                     + self.role_prefix_and_suffix["system"][0],
-                     self.role_prefix_and_suffix["system"][1],
+     def get_prefix_and_suffix(
+         self, role: str, hist_messages: List[Dict]
+     ) -> Tuple[str, str]:
+         prefix, suffix = self.role_prefix_and_suffix.get(role, ("", ""))
+
+         if self.style == ChatTemplateStyle.LLAMA2:
+             if role == "system" and not hist_messages:
+                 user_prefix, _ = self.role_prefix_and_suffix.get("user", ("", ""))
+                 system_prefix, system_suffix = self.role_prefix_and_suffix.get(
+                     "system", ("", "")
                  )
+                 return (user_prefix + system_prefix, system_suffix)
              elif (
-                 len(hist_messages) == 1
-                 and role == "user"
+                 role == "user"
+                 and len(hist_messages) == 1
                  and hist_messages[0]["content"] is not None
              ):
-                 return ("", self.role_prefix_and_suffix["user"][1])
-             return self.role_prefix_and_suffix[role]
-         else:
-             raise ValueError(f"Invalid style: {self.style}")
+                 return ("", suffix)
+
+         return prefix, suffix
 
-     def get_prompt(self, messages):
+     def get_prompt(self, messages: List[Dict]) -> str:
          prompt = ""
-         for i in range(len(messages)):
-             role, content = messages[i]["role"], messages[i]["content"]
+         for i, message in enumerate(messages):
+             role, content = message["role"], message["content"]
              if role == "system" and content is None:
                  content = self.default_system_prompt
              if content is None:
                  continue
 
              prefix, suffix = self.get_prefix_and_suffix(role, messages[:i])
-             prompt += prefix + content + suffix
+             prompt += f"{prefix}{content}{suffix}"
          return prompt
 
 
@@ -106,9 +107,9 @@ register_chat_template(
          name="chatml",
          default_system_prompt=None,
          role_prefix_and_suffix={
-             "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
-             "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
-             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+             "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
          },
          style=ChatTemplateStyle.PLAIN,
          stop_str=("<|im_end|>",),
@@ -121,9 +122,9 @@ register_chat_template(
          name="chatml-llava",
          default_system_prompt="Answer the questions.",
          role_prefix_and_suffix={
-             "system": ("<|im_start|>system\n", "\n<|im_end|>\n"),
-             "user": ("<|im_start|>user\n", "\n<|im_end|>\n"),
-             "assistant": ("<|im_start|>assistant\n", "\n<|im_end|>\n"),
+             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
+             "user": ("<|im_start|>user\n", "<|im_end|>\n"),
+             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
          },
          style=ChatTemplateStyle.PLAIN,
          stop_str=("<|im_end|>",),
@@ -178,6 +179,19 @@ register_chat_template(
      )
  )
 
+ register_chat_template(
+     ChatTemplate(
+         name="gemma-it",
+         default_system_prompt=None,
+         role_prefix_and_suffix={
+             "system": ("", ""),
+             "user": ("<start_of_turn>user\n", "<end_of_turn>\n"),
+             "assistant": ("<start_of_turn>model\n", "<end_of_turn>\n"),
+         },
+         style=ChatTemplateStyle.PLAIN,
+     )
+ )
+
 
  @register_chat_template_matching_function
  def match_vicuna(model_path: str):
@@ -218,6 +232,13 @@ def match_chat_yi(model_path: str):
      return get_chat_template("yi")
 
 
+ @register_chat_template_matching_function
+ def match_gemma_it(model_path: str):
+     model_path = model_path.lower()
+     if "gemma" in model_path and "it" in model_path:
+         return get_chat_template("gemma-it")
+
+
  if __name__ == "__main__":
      messages = [
          {"role": "system", "content": None},  # None means default
sglang/lang/interpreter.py CHANGED
@@ -245,6 +245,9 @@ class StreamExecutor:
              self.variable_event[name].wait()
          return self.variables[name]
 
+     def set_var(self, name, value):
+         self.variables[name] = value
+
      def get_meta_info(self, name):
          if name in self.variable_event:
              self.variable_event[name].wait()
@@ -583,6 +586,10 @@ class StreamExecutor:
          if self.chat_template.stop_str:
              if not clone:
                  clone = self.default_sampling_para.clone()
+             if clone.stop == ():
+                 clone.stop = []
+             elif isinstance(clone.stop, str):
+                 clone.stop = [clone.stop]
              clone.stop += self.chat_template.stop_str
 
          return clone or self.default_sampling_para
@@ -679,7 +686,7 @@ class ProgramState:
              if var_name is None:
                  yield self.text()
              else:
-                 yield self.get_var(name)
+                 yield self.get_var(var_name)
 
      async def text_async_iter(
          self, var_name: Optional[str] = None, return_meta_data: bool = False
@@ -717,11 +724,14 @@ class ProgramState:
              if var_name is None:
                  yield self.text()
              else:
-                 yield self.get_var(name)
+                 yield self.get_var(var_name)
 
      def get_var(self, name):
          return self.stream_executor.get_var(name)
 
+     def set_var(self, name, value):
+         return self.stream_executor.set_var(name, value)
+
      def get_meta_info(self, name):
          return self.stream_executor.get_meta_info(name)
 
@@ -732,6 +742,9 @@ class ProgramState:
      def __getitem__(self, name):
          return self.get_var(name)
 
+     def __setitem__(self, name, value):
+         self.set_var(name, value)
+
      def __del__(self):
          self.stream_executor.end()
 
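Note: ProgramState can now write variables as well as read them (set_var / __setitem__), and text_iter / text_async_iter yield the variable named by var_name instead of an undefined name. A small sketch; qa is any @sgl.function with a backend already configured.

    state = qa.run(question="2 + 2 = ?")

    answer = state["answer"]            # __getitem__ -> get_var
    state["answer"] = answer.strip()    # new: __setitem__ -> set_var

    for chunk in state.text_iter("answer"):   # now resolves var_name correctly
        print(chunk)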
sglang/lang/ir.py CHANGED
@@ -73,7 +73,7 @@ class SglSamplingParams:
                  "Regular expression is not supported in the Anthropic backend."
              )
          return {
-             "max_tokens_to_sample": self.max_new_tokens,
+             "max_tokens": self.max_new_tokens,
              "stop_sequences": (
                  self.stop if isinstance(self.stop, (list, tuple)) else [self.stop]
              ),
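Note: to_anthropic_kwargs now emits max_tokens, the parameter name used by Anthropic's Messages API, instead of the legacy max_tokens_to_sample. A hedged illustration of the result; the constructor call and the omitted fields are assumptions, and only the keys visible in this hunk are shown.

    params = SglSamplingParams(max_new_tokens=128, stop="</s>")
    kwargs = params.to_anthropic_kwargs()
    # kwargs now contains {"max_tokens": 128, "stop_sequences": ["</s>"], ...}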
sglang/srt/constrained/__init__.py CHANGED
@@ -1,9 +1,31 @@
+ import json
+ from typing import Dict, Optional, Union
+
  from outlines.caching import cache as disk_cache
  from outlines.caching import disable_cache
  from outlines.fsm.fsm import RegexFSM
- from outlines.fsm.json_schema import build_regex_from_object
  from outlines.fsm.regex import FSMInfo, make_deterministic_fsm
  from outlines.models.transformers import TransformerTokenizer
+ from pydantic import BaseModel
+
+ try:
+     from outlines.fsm.json_schema import build_regex_from_object
+ except ImportError:
+     # Since outlines 0.0.32, build_regex_from_object is replaced by build_regex_from_schema,
+     # which only accepts string schema as input.
+     from outlines.fsm.json_schema import build_regex_from_schema
+
+     def build_regex_from_object(
+         object: Union[str, BaseModel, Dict], whitespace_pattern: Optional[str] = None
+     ):
+         if isinstance(object, type(BaseModel)):
+             schema = json.dumps(object.model_json_schema())
+         elif isinstance(object, Dict):
+             schema = json.dumps(object)
+         else:
+             schema = object
+         return build_regex_from_schema(schema, whitespace_pattern)
+
 
  __all__ = [
      "RegexFSM",
sglang/srt/constrained/fsm_cache.py CHANGED
@@ -5,9 +5,20 @@ from sglang.srt.constrained.base_cache import BaseCache
  class FSMCache(BaseCache):
      def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
          super().__init__(enable=enable)
-         self.outlines_tokenizer = TransformerTokenizer(
-             tokenizer_path, **tokenizer_args_dict
-         )
+
+         from importlib.metadata import version
+         if version("outlines") >= "0.0.35":
+             from transformers import AutoTokenizer
+
+             tokenizer_args_dict.setdefault("padding_side", "left")
+             tokenizer = AutoTokenizer.from_pretrained(
+                 tokenizer_path, **tokenizer_args_dict
+             )
+             self.outlines_tokenizer = TransformerTokenizer(tokenizer)
+         else:
+             self.outlines_tokenizer = TransformerTokenizer(
+                 tokenizer_path, **tokenizer_args_dict
+             )
 
      def init_value(self, regex):
          return RegexFSM(regex, self.outlines_tokenizer)
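Note: the cache builds outlines RegexFSM objects, which back regex-constrained generation at the SGLang API level. A hedged reminder of the user-facing call this supports; the regex and prompt are illustrative.

    import sglang as sgl

    @sgl.function
    def extract_year(s, text):
        s += sgl.user("Which year is mentioned? " + text)
        s += sgl.assistant(sgl.gen("year", regex=r"(19|20)\d{2}"))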
sglang/srt/layers/context_flashattention_nopad.py CHANGED
@@ -129,7 +129,7 @@ def context_attention_fwd(q, k, v, o, b_start_loc, b_seq_len, max_input_len):
 
      Lq, Lk, Lv = q.shape[-1], k.shape[-1], v.shape[-1]
      assert Lq == Lk and Lk == Lv
-     assert Lk in {16, 32, 64, 128}
+     assert Lk in {16, 32, 64, 128, 256}
 
      sm_scale = 1.0 / (Lq**0.5)
      batch, head = b_seq_len.shape[0], q.shape[1]
sglang/srt/layers/extend_attention.py CHANGED
@@ -181,19 +181,20 @@ def extend_attention_fwd(
 
      k_buffer, v_buffer: (prefix + extend) tensors in mem_manager
      """
-     if CUDA_CAPABILITY[0] >= 8:
-         BLOCK_M, BLOCK_N = 128, 128
-     else:
-         BLOCK_M, BLOCK_N = 64, 64
-
      Lq, Lk, Lv, Lo = (
          q_extend.shape[-1],
          k_extend.shape[-1],
          v_extend.shape[-1],
          o_extend.shape[-1],
      )
+
      assert Lq == Lk and Lk == Lv and Lv == Lo
-     assert Lq in {16, 32, 64, 128}
+     assert Lq in {16, 32, 64, 128, 256}
+
+     if CUDA_CAPABILITY[0] >= 8:
+         BLOCK_M, BLOCK_N = (128, 128) if Lq <= 128 else (64, 64)
+     else:
+         BLOCK_M, BLOCK_N = (64, 64) if Lq <= 128 else (32, 32)
 
      sm_scale = 1.0 / (Lq**0.5)
      batch_size, head_num = b_seq_len.shape[0], q_extend.shape[1]
sglang/srt/layers/radix_attention.py CHANGED
@@ -1,15 +1,9 @@
- from typing import List
-
  import torch
  from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
  from sglang.srt.layers.extend_attention import extend_attention_fwd
  from sglang.srt.layers.token_attention import token_attention_fwd
  from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
  from torch import nn
- from vllm.model_executor.parallel_utils.parallel_state import (
-     get_tensor_model_parallel_rank,
-     get_tensor_model_parallel_world_size,
- )
 
 
  class RadixAttention(nn.Module):
@@ -21,11 +15,9 @@ class RadixAttention(nn.Module):
          self.head_dim = head_dim
          self.layer_id = layer_id
 
-         from sglang.srt.managers.router.model_runner import global_model_mode
-
-         self.use_flashinfer = "flashinfer" in global_model_mode
+         from sglang.srt.managers.router.model_runner import global_server_args_dict
 
-         if self.use_flashinfer:
+         if global_server_args_dict.get("enable_flashinfer", False):
              self.prefill_forward = self.prefill_forward_flashinfer
              self.extend_forward = self.prefill_forward_flashinfer
              self.decode_forward = self.decode_forward_flashinfer
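Note: RadixAttention now reads its flashinfer switch from global_server_args_dict, which is populated from the server arguments (server_args.py also changes in this release, +44 -9 above). A hedged launch sketch; the enable_flashinfer keyword is inferred from the dictionary key used here and is not confirmed by this diff.

    import sglang as sgl

    # Illustrative only: start a local runtime with the flashinfer kernels enabled.
    runtime = sgl.Runtime(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # placeholder model
        enable_flashinfer=True,
    )
    sgl.set_default_backend(runtime)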