sglang 0.3.4.post2__py3-none-any.whl → 0.3.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. sglang/api.py +1 -1
  2. sglang/bench_latency.py +3 -3
  3. sglang/bench_server_latency.py +2 -3
  4. sglang/bench_serving.py +92 -0
  5. sglang/global_config.py +9 -3
  6. sglang/lang/chat_template.py +50 -25
  7. sglang/lang/interpreter.py +9 -1
  8. sglang/lang/ir.py +11 -2
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/configs/model_config.py +51 -13
  11. sglang/srt/constrained/__init__.py +18 -0
  12. sglang/srt/constrained/bnf_cache.py +61 -0
  13. sglang/srt/constrained/grammar.py +190 -0
  14. sglang/srt/hf_transformers_utils.py +6 -5
  15. sglang/srt/layers/attention/triton_ops/decode_attention.py +110 -30
  16. sglang/srt/layers/attention/triton_ops/prefill_attention.py +1 -1
  17. sglang/srt/layers/fused_moe/fused_moe.py +4 -3
  18. sglang/srt/layers/fused_moe/layer.py +28 -0
  19. sglang/srt/layers/quantization/base_config.py +16 -1
  20. sglang/srt/layers/vocab_parallel_embedding.py +486 -0
  21. sglang/srt/managers/data_parallel_controller.py +7 -6
  22. sglang/srt/managers/detokenizer_manager.py +9 -11
  23. sglang/srt/managers/image_processor.py +4 -3
  24. sglang/srt/managers/io_struct.py +70 -78
  25. sglang/srt/managers/schedule_batch.py +33 -49
  26. sglang/srt/managers/schedule_policy.py +24 -13
  27. sglang/srt/managers/scheduler.py +137 -80
  28. sglang/srt/managers/tokenizer_manager.py +224 -336
  29. sglang/srt/managers/tp_worker.py +5 -5
  30. sglang/srt/mem_cache/flush_cache.py +1 -1
  31. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  32. sglang/srt/model_executor/model_runner.py +8 -17
  33. sglang/srt/models/baichuan.py +4 -4
  34. sglang/srt/models/chatglm.py +4 -4
  35. sglang/srt/models/commandr.py +1 -1
  36. sglang/srt/models/dbrx.py +5 -5
  37. sglang/srt/models/deepseek.py +4 -4
  38. sglang/srt/models/deepseek_v2.py +4 -4
  39. sglang/srt/models/exaone.py +4 -4
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +1 -1
  42. sglang/srt/models/gpt2.py +287 -0
  43. sglang/srt/models/gpt_bigcode.py +1 -1
  44. sglang/srt/models/grok.py +4 -4
  45. sglang/srt/models/internlm2.py +4 -4
  46. sglang/srt/models/llama.py +15 -7
  47. sglang/srt/models/llama_embedding.py +2 -10
  48. sglang/srt/models/llama_reward.py +5 -0
  49. sglang/srt/models/minicpm.py +4 -4
  50. sglang/srt/models/minicpm3.py +4 -4
  51. sglang/srt/models/mixtral.py +7 -5
  52. sglang/srt/models/mixtral_quant.py +4 -4
  53. sglang/srt/models/mllama.py +5 -5
  54. sglang/srt/models/olmo.py +4 -4
  55. sglang/srt/models/olmoe.py +4 -4
  56. sglang/srt/models/qwen.py +4 -4
  57. sglang/srt/models/qwen2.py +4 -4
  58. sglang/srt/models/qwen2_moe.py +4 -4
  59. sglang/srt/models/qwen2_vl.py +4 -8
  60. sglang/srt/models/stablelm.py +4 -4
  61. sglang/srt/models/torch_native_llama.py +4 -4
  62. sglang/srt/models/xverse.py +4 -4
  63. sglang/srt/models/xverse_moe.py +4 -4
  64. sglang/srt/openai_api/adapter.py +52 -66
  65. sglang/srt/sampling/sampling_batch_info.py +7 -13
  66. sglang/srt/server.py +31 -35
  67. sglang/srt/server_args.py +34 -5
  68. sglang/srt/utils.py +40 -56
  69. sglang/test/runners.py +2 -1
  70. sglang/test/test_utils.py +73 -25
  71. sglang/utils.py +62 -1
  72. sglang/version.py +1 -1
  73. sglang-0.3.5.dist-info/METADATA +344 -0
  74. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/RECORD +77 -73
  75. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/WHEEL +1 -1
  76. sglang-0.3.4.post2.dist-info/METADATA +0 -899
  77. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/LICENSE +0 -0
  78. {sglang-0.3.4.post2.dist-info → sglang-0.3.5.dist-info}/top_level.txt +0 -0
sglang/api.py CHANGED
@@ -99,7 +99,7 @@ def gen(
     regex: Optional[str] = None,
     json_schema: Optional[str] = None,
 ):
-    """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
+    """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""

     if choices:
         return SglSelect(
sglang/bench_latency.py CHANGED
@@ -129,9 +129,9 @@ def load_model(server_args, port_args, tp_rank):

     model_config = ModelConfig(
         server_args.model_path,
-        server_args.trust_remote_code,
+        trust_remote_code=server_args.trust_remote_code,
         context_length=server_args.context_length,
-        model_override_args=json.loads(server_args.json_model_override_args),
+        model_override_args=server_args.json_model_override_args,
     )
     model_runner = ModelRunner(
         model_config=model_config,
@@ -550,4 +550,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
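bench_latency.py now calls kill_child_process() with no arguments, and the same swap from including_parent to include_self appears in bench_server_latency.py and launch_server.py below. A minimal sketch of the assumed new call pattern (the import path matches what these scripts use; treat the exact defaults as assumptions):

    import multiprocessing
    import time

    from sglang.srt.utils import kill_child_process  # assumed import path

    def _serve():
        time.sleep(60)  # stand-in for a long-running server process

    if __name__ == "__main__":
        proc = multiprocessing.Process(target=_serve)
        proc.start()
        try:
            pass  # benchmark or launcher work would go here
        finally:
            # New default: reap children of the calling process, keep the caller alive.
            kill_child_process()
            # For an external server process, pass its pid and terminate it as well,
            # as the benchmark script below does.
            kill_child_process(proc.pid, include_self=True)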
sglang/bench_server_latency.py CHANGED
@@ -15,7 +15,6 @@ import dataclasses
 import itertools
 import json
 import multiprocessing
-import os
 import time
 from typing import Tuple

@@ -70,7 +69,7 @@ def launch_server_internal(server_args):
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()


 def launch_server_process(server_args: ServerArgs):
@@ -176,7 +175,7 @@ def run_benchmark(server_args: ServerArgs, bench_args: BenchArgs):
         )
     finally:
         if proc:
-            kill_child_process(proc.pid)
+            kill_child_process(proc.pid, include_self=True)

     print(f"\nResults are saved to {bench_args.result_filename}")

sglang/bench_serving.py CHANGED
@@ -222,6 +222,85 @@ async def async_request_openai_completions(
     return output


+async def async_request_truss(
+    request_func_input: RequestFuncInput,
+    pbar: Optional[tqdm] = None,
+) -> RequestFuncOutput:
+    api_url = request_func_input.api_url
+
+    prompt = request_func_input.prompt
+
+    async with aiohttp.ClientSession(timeout=AIOHTTP_TIMEOUT) as session:
+        payload = {
+            "model": request_func_input.model,
+            "prompt": prompt,
+            "temperature": 0.0,
+            "best_of": 1,
+            "max_tokens": request_func_input.output_len,
+            "stream": not args.disable_stream,
+            "ignore_eos": not args.disable_ignore_eos,
+            **request_func_input.extra_request_body,
+        }
+        headers = {"Authorization": f"Bearer {os.environ.get('OPENAI_API_KEY')}"}
+
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+
+        generated_text = ""
+        ttft = 0.0
+        st = time.perf_counter()
+        most_recent_timestamp = st
+        try:
+            async with session.post(
+                url=api_url, json=payload, headers=headers
+            ) as response:
+                if response.status == 200:
+                    async for chunk_bytes in response.content:
+                        chunk_bytes = chunk_bytes.strip()
+                        if not chunk_bytes:
+                            continue
+
+                        chunk = remove_prefix(chunk_bytes.decode("utf-8"), "data: ")
+                        latency = time.perf_counter() - st
+                        if chunk == "[DONE]":
+                            pass
+                        else:
+                            data = json.loads(chunk)
+
+                            # NOTE: Some completion API might have a last
+                            # usage summary response without a token so we
+                            # want to check a token was generated
+                            if data["choices"][0]["delta"]["content"]:
+                                timestamp = time.perf_counter()
+                                # First token
+                                if ttft == 0.0:
+                                    ttft = time.perf_counter() - st
+                                    output.ttft = ttft
+
+                                # Decoding phase
+                                else:
+                                    output.itl.append(timestamp - most_recent_timestamp)
+
+                                most_recent_timestamp = timestamp
+                                generated_text += data["choices"][0]["delta"]["content"]
+
+                    output.generated_text = generated_text
+                    output.success = True
+                    output.latency = latency
+                    output.output_len = request_func_input.output_len
+                else:
+                    output.error = response.reason or ""
+                    output.success = False
+        except Exception:
+            output.success = False
+            exc_info = sys.exc_info()
+            output.error = "".join(traceback.format_exception(*exc_info))
+
+    if pbar:
+        pbar.update(1)
+    return output
+
+
 async def async_request_sglang_generate(
     request_func_input: RequestFuncInput,
     pbar: Optional[tqdm] = None,
@@ -350,6 +429,7 @@ ASYNC_REQUEST_FUNCS = {
     "lmdeploy": async_request_openai_completions,
     "trt": async_request_trt_llm,
     "gserver": async_request_gserver,
+    "truss": async_request_truss,
 }


@@ -873,6 +953,7 @@ def run_benchmark(args_: argparse.Namespace):
         "vllm": 8000,
         "trt": 8000,
         "gserver": 9988,
+        "truss": 8080,
     }.get(args.backend, 30000)

     model_url = (
@@ -905,9 +986,20 @@
     elif args.backend == "gserver":
         api_url = args.base_url if args.base_url else f"{args.host}:{args.port}"
         args.model = args.model or "default"
+    elif args.backend == "truss":
+        api_url = (
+            f"{args.base_url}/v1/models/model:predict"
+            if args.base_url
+            else f"http://{args.host}:{args.port}/v1/models/model:predict"
+        )

     # Get model name
     if args.model is None:
+        if args.backend == "truss":
+            print(
+                "Please provide a model with `--model` when using truss backend. e.g. --model meta-llama/Llama-3.1-8B-Instruct"
+            )
+            sys.exit(1)
         try:
             response = requests.get(model_url)
             model_list = response.json().get("data", [])
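The new truss backend reuses the OpenAI-style streaming response format but posts to a Truss predict endpoint (default port 8080) and requires --model to be supplied explicitly. A small, runnable sketch of how the script derives that endpoint, mirroring the branch above (host, port, and base_url are example values):

    host, port, base_url = "127.0.0.1", 8080, None
    api_url = (
        f"{base_url}/v1/models/model:predict"
        if base_url
        else f"http://{host}:{port}/v1/models/model:predict"
    )
    print(api_url)  # -> http://127.0.0.1:8080/v1/models/model:predict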
sglang/global_config.py CHANGED
@@ -14,9 +14,15 @@ class GlobalConfig:
         self.default_backend = None

         # Runtime constants: New generation token ratio estimation
-        self.init_new_token_ratio = 0.7
-        self.base_min_new_token_ratio = 0.1
-        self.new_token_ratio_decay = 0.001
+        self.default_init_new_token_ratio = float(
+            os.environ.get("SGLANG_INIT_NEW_TOKEN_RATIO", 0.7)
+        )
+        self.default_min_new_token_ratio_factor = float(
+            os.environ.get("SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR", 0.14)
+        )
+        self.default_new_token_ratio_decay_steps = float(
+            os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
+        )

         # Runtime constants: others
         self.retract_decode_steps = 20
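These new-token-ratio defaults can now seemingly be overridden through environment variables instead of editing the source. A minimal sketch, assuming global_config is the module-level GlobalConfig instance and the variables are set before the module is imported:

    import os

    # Must be set before sglang.global_config is imported / the server starts.
    os.environ["SGLANG_INIT_NEW_TOKEN_RATIO"] = "0.6"
    os.environ["SGLANG_MIN_NEW_TOKEN_RATIO_FACTOR"] = "0.2"
    os.environ["SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS"] = "400"

    from sglang.global_config import global_config  # assumed module-level singleton

    print(global_config.default_init_new_token_ratio)  # 0.6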
sglang/lang/chat_template.py CHANGED
@@ -116,12 +116,10 @@ register_chat_template(
     )
 )

-# There is default system prompt for qwen
-# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
-# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
+
 register_chat_template(
     ChatTemplate(
-        name="qwen",
+        name="chatml-llava",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -130,13 +128,17 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
+        image_token="<image>\n",
     )
 )

-# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
+
+# There is default system prompt for qwen
+# reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
+# The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
 register_chat_template(
     ChatTemplate(
-        name="qwen2-vl",
+        name="qwen",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -144,15 +146,14 @@ register_chat_template(
             "assistant": ("<|im_start|>assistant\n", "<|im_end|>\n"),
         },
         style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>"),
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
+        stop_str=("<|im_end|>",),
     )
 )

-
+# Reference: https://huggingface.co/docs/transformers/main/model_doc/qwen2_vl#usage-example
 register_chat_template(
     ChatTemplate(
-        name="chatml-llava",
+        name="qwen2-vl",
         default_system_prompt="You are a helpful assistant.",
         role_prefix_and_suffix={
             "system": ("<|im_start|>system\n", "<|im_end|>\n"),
@@ -161,7 +162,7 @@ register_chat_template(
         },
         style=ChatTemplateStyle.PLAIN,
         stop_str=("<|im_end|>",),
-        image_token="<image>\n",
+        image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )

@@ -182,37 +183,46 @@ register_chat_template(
     )
 )

-# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
-        name="yi-1.5",
+        name="llama-2-chat",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("", ""),
-            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
-            "assistant": ("", "<|im_end|>\n"),
+            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
+            "user": ("[INST] ", " [/INST]"),
+            "assistant": ("", " </s><s>"),
         },
-        style=ChatTemplateStyle.PLAIN,
-        stop_str=("<|im_end|>",),
+        style=ChatTemplateStyle.LLAMA2,
     )
 )

 register_chat_template(
     ChatTemplate(
-        name="llama-2-chat",
+        name="llama-3-instruct",
         default_system_prompt=None,
         role_prefix_and_suffix={
-            "system": ("<<SYS>>\n", "\n<</SYS>>\n\n"),
-            "user": ("[INST] ", " [/INST]"),
-            "assistant": ("", " </s><s>"),
+            "system": (
+                "<|start_header_id|>system<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "user": (
+                "<|start_header_id|>user<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
+            "assistant": (
+                "<|start_header_id|>assistant<|end_header_id|>\n\n",
+                "<|eot_id|>",
+            ),
         },
-        style=ChatTemplateStyle.LLAMA2,
+        stop_str=("<|eot_id|>",),
+        image_token="<|image|>",
     )
 )

+# The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
-        name="llama-3-instruct",
+        name="llama-3-instruct-llava",
         default_system_prompt=None,
         role_prefix_and_suffix={
             "system": (
@@ -229,7 +239,22 @@ register_chat_template(
             ),
         },
         stop_str=("<|eot_id|>",),
-        image_token="<|image|>",
+        image_token="<image>\n",
+    )
+)
+
+# Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
+register_chat_template(
+    ChatTemplate(
+        name="yi-1.5",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", ""),
+            "user": ("<|im_start|>user\n", "<|im_end|>\n<|im_start|>assistant\n"),
+            "assistant": ("", "<|im_end|>\n"),
+        },
+        style=ChatTemplateStyle.PLAIN,
+        stop_str=("<|im_end|>",),
     )
 )
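Net effect of the reordering above: the chatml-llava entry keeps the plain "<image>\n" token, qwen2-vl switches to the vision-pad tokens, and a llama-3-instruct-llava variant is added that differs from llama-3-instruct only in its image_token. A quick way to inspect this, assuming get_chat_template is the registry lookup that pairs with register_chat_template:

    from sglang.lang.chat_template import get_chat_template

    print(get_chat_template("qwen2-vl").image_token)                # configured as <|vision_start|><|image_pad|><|vision_end|>
    print(get_chat_template("chatml-llava").image_token)            # configured as "<image>\n"
    print(get_chat_template("llama-3-instruct-llava").image_token)  # configured as "<image>\n"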
 
sglang/lang/interpreter.py CHANGED
@@ -54,7 +54,14 @@ def run_internal(state, program, func_args, func_kwargs, sync):


 def run_program(
-    program, backend, func_args, func_kwargs, default_sampling_para, stream, sync=False
+    program,
+    backend,
+    func_args,
+    func_kwargs,
+    default_sampling_para,
+    stream,
+    sync=False,
+    use_thread=True,
 ):
     if hasattr(backend, "endpoint"):
         backend = backend.endpoint
@@ -67,6 +74,7 @@ def run_program(
         chat_template=None,
         stream=stream,
         num_api_spec_tokens=program.num_api_spec_tokens,
+        use_thread=use_thread,
     )
     state = ProgramState(stream_executor)

sglang/lang/ir.py CHANGED
@@ -168,6 +168,7 @@ class SglFunction:
         return_text_in_logprobs: Optional[bool] = None,
         stream: bool = False,
         backend=None,
+        use_thread: bool = True,
         **kwargs,
     ):
         from sglang.lang.interpreter import run_program
@@ -195,7 +196,15 @@
             return_text_in_logprobs=return_text_in_logprobs,
         )
         backend = backend or global_config.default_backend
-        return run_program(self, backend, args, kwargs, default_sampling_para, stream)
+        return run_program(
+            self,
+            backend,
+            args,
+            kwargs,
+            default_sampling_para,
+            stream,
+            use_thread=use_thread,
+        )

     def run_batch(
         self,
@@ -445,7 +454,7 @@ class SglGen(SglExpr):
         regex: Optional[str] = None,
         json_schema: Optional[str] = None,
     ):
-        """Call the model to generate. See the meaning of the arguments in docs/en/sampling_params.md"""
+        """Call the model to generate. See the meaning of the arguments in docs/sampling_params.md"""
         super().__init__()
         self.name = name
         self.sampling_params = SglSamplingParams(
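The use_thread flag threaded through SglFunction.run and run_program above lets a program execute in the caller's thread instead of a background one. A hedged sketch of how it would be passed from user code (the function body is illustrative only, and a default backend must be configured before run is called):

    import sglang as sgl

    @sgl.function
    def qa(s, question):
        s += "Q: " + question + "\n"
        s += "A: " + sgl.gen("answer", max_tokens=32)

    # state = qa.run(question="What is SGLang?", use_thread=False)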
sglang/launch_server.py CHANGED
@@ -15,4 +15,4 @@ if __name__ == "__main__":
     except Exception as e:
         raise e
     finally:
-        kill_child_process(os.getpid(), including_parent=False)
+        kill_child_process()
sglang/srt/configs/model_config.py CHANGED
@@ -13,10 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """

+import json
 import logging
 import os
 from enum import IntEnum, auto
-from typing import Optional
+from typing import List, Optional

 from transformers import PretrainedConfig

@@ -38,18 +39,24 @@ class ModelConfig:
         revision: Optional[str] = None,
         context_length: Optional[int] = None,
         model_override_args: Optional[dict] = None,
+        is_embedding: Optional[bool] = None
     ) -> None:
-        self.path = path
-        self.trust_remote_code = trust_remote_code
-        self.revision = revision
-        self.model_override_args = model_override_args
+        # Parse args
+        self.model_override_args = json.loads(model_override_args)
         self.hf_config = get_config(
-            self.path,
-            trust_remote_code,
-            revision,
-            model_override_args=model_override_args,
+            path,
+            trust_remote_code=trust_remote_code,
+            revision=revision,
+            model_override_args=self.model_override_args,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+
+        # Check model type
+        self.is_generation = is_generation_model(self.hf_config.architectures, is_embedding)
+        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
+        self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
+
+        # Derive context length
         derived_context_len = get_context_length(self.hf_text_config)
         allow_long_context = os.environ.get(
             "SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN", None
@@ -81,7 +88,7 @@ class ModelConfig:
             self.hf_text_config.hidden_size // self.hf_text_config.num_attention_heads,
         )

-        # FIXME: temporary special judge for deepseek v2 MLA architecture
+        # FIXME: temporary special judge for MLA architecture
         if "DeepseekV2ForCausalLM" in self.hf_config.architectures:
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
@@ -112,8 +119,6 @@ class ModelConfig:
         self.num_hidden_layers = self.hf_text_config.num_hidden_layers
         self.vocab_size = self.hf_text_config.vocab_size

-        self.is_encoder_decoder = self.hf_config.model_type in ["mllama"]
-
     # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L289
     def get_total_num_kv_heads(self) -> int:
         """Returns the total number of KV heads."""
@@ -163,7 +168,6 @@ class ModelConfig:
         # equal to the number of attention heads.
         return self.hf_text_config.num_attention_heads

-    # adapted from https://github.com/vllm-project/vllm/blob/main/vllm/config.py#L328
     def get_num_kv_heads(self, tensor_parallel_size) -> int:
         """Returns the number of KV heads per GPU."""
         total_num_kv_heads = self.get_total_num_kv_heads()
@@ -192,3 +196,37 @@ def get_hf_text_config(config: PretrainedConfig):
         return config.text_config
     else:
         return config
+
+
+def is_generation_model(model_architectures: List[str], is_embedding: bool = False):
+    # We have two ways to determine whether a model is a generative model.
+    # 1. Check the model architectue
+    # 2. check the `is_embedding` server args
+
+    if (
+        "LlamaEmbeddingModel" in model_architectures
+        or "MistralModel" in model_architectures
+        or "LlamaForSequenceClassification" in model_architectures
+        or "LlamaForSequenceClassificationWithNormal_Weights" in model_architectures
+    ):
+        return False
+    else:
+        return not is_embedding
+
+
+def is_multimodal_model(model_architectures: List[str]):
+    if (
+        "LlavaLlamaForCausalLM" in model_architectures
+        or "LlavaQwenForCausalLM" in model_architectures
+        or "LlavaMistralForCausalLM" in model_architectures
+        or "LlavaVidForCausalLM" in model_architectures
+        or "MllamaForConditionalGeneration" in model_architectures
+        or "Qwen2VLForConditionalGeneration" in model_architectures
+    ):
+        return True
+    else:
+        return False
+
+
+def is_encoder_decoder_model(model_architectures: List[str]):
+    return "MllamaForConditionalGeneration" in model_architectures
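ModelConfig now parses model_override_args itself (callers pass the raw JSON string, as the bench_latency change earlier shows) and exposes is_generation, is_multimodal, and is_encoder_decoder flags derived from the architecture list plus the optional is_embedding override. A hedged construction sketch (the model path is an example, not taken from this diff):

    from sglang.srt.configs.model_config import ModelConfig

    config = ModelConfig(
        "meta-llama/Llama-3.1-8B-Instruct",
        trust_remote_code=False,
        context_length=None,
        model_override_args="{}",
        is_embedding=False,
    )
    print(config.is_generation, config.is_multimodal, config.is_encoder_decoder)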
sglang/srt/constrained/__init__.py CHANGED
@@ -51,6 +51,21 @@ except ImportError:
         return build_regex_from_schema(schema, whitespace_pattern)


+try:
+    from xgrammar import (
+        GrammarMatcher,
+        GrammarMatcherInitContext,
+        GrammarMatcherInitContextCache,
+    )
+except ImportError as e:
+
+    class Dummy:
+        pass
+
+    GrammarMatcher = Dummy
+    GrammarMatcherInitContext = Dummy
+    GrammarMatcherInitContextCache = Dummy
+
 __all__ = [
     "RegexGuide",
     "FSMInfo",
@@ -60,4 +75,7 @@ __all__ = [
     "disk_cache",
     "disable_cache",
     "make_byte_level_fsm",
+    "GrammarMatcher",
+    "GrammarMatcherInitContext",
+    "GrammarMatcherInitContextCache",
 ]
sglang/srt/constrained/bnf_cache.py ADDED
@@ -0,0 +1,61 @@
+"""
+Copyright 2023-2024 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+    http://www.apache.org/licenses/LICENSE-2.0
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""Cache for the compressed finite state machine."""
+
+from typing import Tuple
+
+from transformers import AutoTokenizer
+
+from sglang.srt.constrained import (
+    GrammarMatcher,
+    GrammarMatcherInitContext,
+    GrammarMatcherInitContextCache,
+)
+
+MAX_ROLLBACK_TOKENS = 10
+
+
+class BNFCache:
+    grammar_cache: GrammarMatcherInitContextCache
+
+    def __init__(
+        self,
+        tokenizer_path,
+        tokenizer_args_dict,
+        skip_tokenizer_init=False,
+        whitespace_patterns=None,
+    ):
+        # TODO(dark): how to deal with whitespace_patterns and skip_tokenizer_init
+        if skip_tokenizer_init:
+            return
+
+        tokenizer = AutoTokenizer.from_pretrained(tokenizer_path, **tokenizer_args_dict)
+        self.grammar_cache = GrammarMatcherInitContextCache(
+            tokenizer_or_vocab=tokenizer
+        )
+
+    def get_context(self, key: Tuple[str, str]) -> GrammarMatcherInitContext:
+        key_type, key_string = key
+        if key_type == "json":
+            return self.grammar_cache.get_init_context_for_json_schema(key_string)
+        elif key_type == "regex":
+            raise ValueError(f"regex hasn't been supported by xgrammar yet")
+        else:
+            raise ValueError(f"Invalid key_type: {key_type}")
+
+    def query(self, key: Tuple[str, str], vocab_size: int) -> GrammarMatcher:
+        ctx = self.get_context(key)
+        return GrammarMatcher(
+            ctx, max_rollback_tokens=MAX_ROLLBACK_TOKENS, mask_vocab_size=vocab_size
+        )
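BNFCache wraps xgrammar's matcher-context cache: query() compiles (or reuses) a grammar context for a ("json", schema) key and returns a fresh GrammarMatcher sized to the tokenizer vocabulary. An illustrative use, assuming the optional xgrammar dependency is installed (the tokenizer path and vocab size below are example values):

    from sglang.srt.constrained.bnf_cache import BNFCache

    cache = BNFCache(
        tokenizer_path="meta-llama/Llama-3.1-8B-Instruct",
        tokenizer_args_dict={},
    )
    schema = '{"type": "object", "properties": {"name": {"type": "string"}}}'
    matcher = cache.query(("json", schema), vocab_size=128256)
    # ("regex", ...) keys currently raise ValueError, per get_context above.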