sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -57,6 +57,7 @@ import torch
 import torch.distributed as dist
 
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed.parallel_state import destroy_distributed_environment
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -85,6 +86,7 @@ class BenchArgs:
     correctness_test: bool = False
     # This is only used for correctness test
     cut_len: int = 4
+    log_decode_step: int = 0
     profile: bool = False
     profile_filename_prefix: str = "profile"
 
@@ -105,6 +107,12 @@ class BenchArgs:
         )
         parser.add_argument("--correctness-test", action="store_true")
         parser.add_argument("--cut-len", type=int, default=BenchArgs.cut_len)
+        parser.add_argument(
+            "--log-decode-step",
+            type=int,
+            default=BenchArgs.log_decode_step,
+            help="Log decode latency by step, default is set to zero to disable.",
+        )
         parser.add_argument(
             "--profile", action="store_true", help="Use Torch Profiler."
         )
@@ -335,6 +343,7 @@ def latency_test_run_once(
     input_len,
     output_len,
     device,
+    log_decode_step,
     profile,
     profile_filename_prefix,
 ):
@@ -394,9 +403,9 @@ def latency_test_run_once(
         tot_latency += latency
         throughput = batch_size / latency
         decode_latencies.append(latency)
-        if i < 5:
+        if i < 5 or (log_decode_step > 0 and i % log_decode_step == 0):
             rank_print(
-                f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+                f"Decode {i}. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
             )
 
     if profile:
@@ -457,8 +466,9 @@ def latency_test(
         reqs,
         bench_args.batch_size[0],
         bench_args.input_len[0],
-        8,  # shorter decoding to speed up the warmup
+        min(32, bench_args.output_len[0]),  # shorter decoding to speed up the warmup
         server_args.device,
+        log_decode_step=0,
         profile=False,
         profile_filename_prefix="",  # not used
     )
@@ -480,6 +490,7 @@ def latency_test(
             il,
             ol,
             server_args.device,
+            bench_args.log_decode_step,
             bench_args.profile if tp_rank == 0 else None,
             bench_args.profile_filename_prefix,
         )
@@ -492,8 +503,13 @@ def latency_test(
             for result in result_list:
                 fout.write(json.dumps(result) + "\n")
 
+    if server_args.tp_size > 1:
+        destroy_distributed_environment()
+
 
 def main(server_args, bench_args):
+    server_args.cuda_graph_max_bs = max(bench_args.batch_size)
+
     _set_envs_and_config(server_args)
 
     if server_args.model_path:
sglang/bench_serving.py CHANGED
@@ -295,7 +295,7 @@ async def async_request_truss(
                             # NOTE: Some completion API might have a last
                             # usage summary response without a token so we
                             # want to check a token was generated
-                            if data["choices"][0]["delta"]["content"]:
+                            if data["choices"][0]["text"]:
                                 timestamp = time.perf_counter()
                                 # First token
                                 if ttft == 0.0:
@@ -307,7 +307,7 @@ async def async_request_truss(
                                     output.itl.append(timestamp - most_recent_timestamp)
 
                                 most_recent_timestamp = timestamp
-                                generated_text += data["choices"][0]["delta"]["content"]
+                                generated_text += data["choices"][0]["text"]
 
                     output.generated_text = generated_text
                     output.success = True
@@ -690,7 +690,6 @@ def sample_random_requests(
     dataset_path: str,
     random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
-
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -978,6 +977,7 @@ async def benchmark(
     profile: bool,
     pd_seperated: bool = False,
     flush_cache: bool = False,
+    warmup_requests: int = 1,
 ):
     if backend in ASYNC_REQUEST_FUNCS:
         request_func = ASYNC_REQUEST_FUNCS[backend]
@@ -995,11 +995,11 @@ async def benchmark(
             return await request_func(request_func_input=request_func_input, pbar=pbar)
 
     # Warmup
-    print(f"Starting warmup with {args.warmup_requests} sequences...")
+    print(f"Starting warmup with {warmup_requests} sequences...")
 
     # Use the first request for all warmup iterations
     test_prompt, test_prompt_len, test_output_len = input_requests[0]
-    if lora_names != None and len(lora_names) != 0:
+    if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
         lora_name = None
@@ -1017,7 +1017,7 @@ async def benchmark(
 
     # Run warmup requests
     warmup_tasks = []
-    for _ in range(args.warmup_requests):
+    for _ in range(warmup_requests):
         warmup_tasks.append(
             asyncio.create_task(request_func(request_func_input=test_input))
         )
@@ -1025,7 +1025,7 @@ async def benchmark(
     warmup_outputs = await asyncio.gather(*warmup_tasks)
 
     # Check if at least one warmup request succeeded
-    if not any(output.success for output in warmup_outputs):
+    if warmup_requests > 0 and not any(output.success for output in warmup_outputs):
         raise ValueError(
             "Warmup failed - Please make sure benchmark arguments "
             f"are correctly specified. Error: {warmup_outputs[0].error}"
@@ -1057,7 +1057,7 @@ async def benchmark(
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
         prompt, prompt_len, output_len = request
-        if lora_names != None and len(lora_names) != 0:
+        if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
         else:
sglang/compile_deep_gemm.py ADDED
@@ -0,0 +1,177 @@
+"""
+Compile DeepGEMM Kernels for a model with specify server arguments
+
+This script launches a server for capturing DeepGEMM calls and then compiles the kernels.
+It accepts server arguments (the same as launch_server.py).
+
+Usage:
+python3 -m sglang.compile_deep_gemm --model deepseek-ai/DeepSeek-V3 --tp 8 --trust-remote-code
+
+"""
+
+import argparse
+import dataclasses
+import multiprocessing
+import os
+import time
+
+import requests
+
+from sglang.srt.entrypoints.http_server import launch_server
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+from sglang.srt.server_args import ServerArgs
+from sglang.srt.utils import kill_process_tree
+from sglang.srt.warmup import warmup
+
+multiprocessing.set_start_method("spawn", force=True)
+
+# Reduce warning
+os.environ["SGL_IN_DEEPGEMM_PRECOMPILE_STAGE"] = "1"
+# Force enable deep gemm
+os.environ["SGL_ENABLE_JIT_DEEPGEMM"] = "1"
+# Force enable mha chunked kv for DeepSeek V3 to avoid missing kv_b_proj DeepGEMM case
+os.environ["SGL_CHUNKED_PREFIX_CACHE_THRESHOLD"] = "0"
+
+
+@dataclasses.dataclass
+class CompileArgs:
+    timeout: int = 3600
+
+    @staticmethod
+    def add_cli_args(parser: argparse.ArgumentParser):
+        parser.add_argument("--timeout", type=int, default=CompileArgs.timeout)
+
+    @classmethod
+    def from_cli_args(cls, args: argparse.Namespace):
+        # use the default value's type to cast the args into correct types.
+        attrs = [(attr.name, type(attr.default)) for attr in dataclasses.fields(cls)]
+        return cls(
+            **{attr: attr_type(getattr(args, attr)) for attr, attr_type in attrs}
+        )
+
+
+@warmup("compile-deep-gemm")
+async def warm_up_compile(tokenizer_manager: TokenizerManager):
+    print("\nGenerate warm up request for compiling DeepGEMM...\n")
+    generate_req_input = GenerateReqInput(
+        input_ids=[0, 1, 2, 3],
+        sampling_params={
+            "temperature": 0.0,
+            "max_new_tokens": 8,
+            "ignore_eos": True,
+        },
+    )
+    await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
+
+
+def launch_server_internal(server_args):
+    try:
+        launch_server(server_args)
+    except Exception as e:
+        raise e
+    finally:
+        kill_process_tree(os.getpid(), include_parent=False)
+
+
+def launch_server_process_and_send_one_request(
+    server_args: ServerArgs, compile_args: CompileArgs
+):
+    proc = multiprocessing.Process(target=launch_server_internal, args=(server_args,))
+    proc.start()
+    base_url = f"http://{server_args.host}:{server_args.port}"
+    timeout = compile_args.timeout
+
+    start_time = time.time()
+    while time.time() - start_time < timeout:
+        try:
+            headers = {
+                "Content-Type": "application/json; charset=utf-8",
+            }
+            if server_args.node_rank == 0:
+                response = requests.get(f"{base_url}/v1/models", headers=headers)
+            else:
+                # This http api is created by launch_dummy_health_check_server for none-rank0 node.
+                response = requests.get(f"{base_url}/health", headers=headers)
+            if response.status_code == 200:
+                # Rank-0 node send a request to sync with other node and then return.
+                if server_args.node_rank == 0:
+                    response = requests.post(
+                        f"{base_url}/generate",
+                        json={
+                            "input_ids": [0, 1, 2, 3],
+                            "sampling_params": {
+                                "max_new_tokens": 8,
+                                "temperature": 0,
+                            },
+                        },
+                        timeout=600,
+                    )
+                    if response.status_code != 200:
+                        error = response.json()
+                        raise RuntimeError(f"Sync request failed: {error}")
+                # Other nodes should wait for the exit signal from Rank-0 node.
+                else:
+                    start_time_waiting = time.time()
+                    while proc.is_alive():
+                        if time.time() - start_time_waiting < timeout:
+                            time.sleep(10)
+                        else:
+                            raise TimeoutError("Waiting for main node timeout!")
+                return proc
+        except requests.RequestException:
+            pass
+        time.sleep(10)
+    raise TimeoutError(
+        "DeepGEMM Kernels compilation timeout."
+        "\n\nFeel free and please restart the command."
+    )
+
+
+def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
+    # Disbale cuda graph and torch compile to save time
+    server_args.disable_cuda_graph = True
+    server_args.enable_torch_compile = False
+    print(f"Disable CUDA Graph and Torch Compile to save time...")
+
+    # Set watchdog timeout to compile_args.timeout because compilation will take a long time
+    server_args.watchdog_timeout = compile_args.timeout
+    server_args.warmups = "compile-deep-gemm"
+
+
+def run_compile(server_args: ServerArgs, compile_args: CompileArgs):
+    print(
+        "Begin DeepGEMM Kernels compilation...\n"
+        "It may take a long time and timeout maybe raised "
+        "while the compilation is still in progress.\n"
+        "Just feel free to restart the command "
+        "until the compilation is fully finished.\n"
+    )
+
+    proc = launch_server_process_and_send_one_request(server_args, compile_args)
+
+    print("\nDeepGEMM Kernels compilation finished successfully.")
+
+    # Sleep for safety
+    time.sleep(10)
+    if proc.is_alive():
+        # This is the rank0 node.
+        kill_process_tree(proc.pid)
+    else:
+        try:
+            kill_process_tree(proc.pid)
+        except Exception:
+            pass
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser()
+    ServerArgs.add_cli_args(parser)
+    CompileArgs.add_cli_args(parser)
+    args = parser.parse_args()
+    server_args = ServerArgs.from_cli_args(args)
+    compile_args = CompileArgs.from_cli_args(args)
+
+    refine_server_args(server_args, compile_args)
+
+    run_compile(server_args, compile_args)
sglang/lang/backend/openai.py CHANGED
@@ -161,7 +161,11 @@ class OpenAI(BaseBackend):
                 prompt = s.text_
 
             kwargs = sampling_params.to_openai_kwargs()
-            if self.model_name.startswith("o1") or self.model_name.startswith("o3"):
+            if (
+                self.model_name.startswith("o1")
+                or self.model_name.startswith("o3")
+                or "o1" in self.model_name
+            ):
                 kwargs.pop("max_tokens", None)
             else:
                 kwargs.pop("max_completion_tokens", None)
sglang/lang/backend/runtime_endpoint.py CHANGED
@@ -324,7 +324,11 @@ class RuntimeEndpoint(BaseBackend):
 
     def _assert_success(self, res):
         if res.status_code != 200:
-            raise RuntimeError(res.json())
+            try:
+                content = res.json()
+            except json.JSONDecodeError:
+                content = res.text
+            raise RuntimeError(content)
 
 
 def compute_normalized_prompt_logprobs(input_logprobs):
sglang/srt/code_completion_parser.py CHANGED
@@ -113,7 +113,7 @@ def completion_template_exists(template_name: str) -> bool:
 
 def is_completion_template_defined() -> bool:
     global completion_template_name
-    return completion_template_name != None
+    return completion_template_name is not None
 
 
 def generate_completion_prompt_from_request(request: ChatCompletionRequest) -> str:
sglang/srt/configs/deepseekvl2.py CHANGED
@@ -182,7 +182,7 @@ class DeepseekVLV2Processor(ProcessorMixin):
             tokenized_str, images, seq_mask, spatial_crop = self.tokenize_with_images(
                 messages,
                 pil_images[image_index : image_index + image_token_cnt],
-                bos=False,
+                bos=True,
                 eos=True,
                 cropping=len(pil_images) <= 2,
                 max_req_input_len=max_req_input_len,
sglang/srt/configs/model_config.py CHANGED
@@ -73,8 +73,15 @@ class ModelConfig:
         )
 
         if enable_multimodal is None:
-            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+            mm_disabled_models = [
+                "Gemma3ForConditionalGeneration",
+                "Llama4ForConditionalGeneration",
+            ]
+            if self.hf_config.architectures[0] in mm_disabled_models:
                 enable_multimodal = False
+                logger.info(
+                    f"Multimodal is disabled for {self.hf_config.model_type}. To enable it, set --enable-multimodal."
+                )
             else:
                 enable_multimodal = True
 
@@ -155,7 +162,9 @@ class ModelConfig:
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_config.kv_lora_rank
             self.qk_rope_head_dim = self.hf_config.qk_rope_head_dim
-        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures:
+        elif "DeepseekVL2ForCausalLM" in self.hf_config.architectures and getattr(
+            self.hf_text_config, "use_mla", True
+        ):
             self.head_dim = 256
             self.attention_arch = AttentionArch.MLA
             self.kv_lora_rank = self.hf_text_config.kv_lora_rank
sglang/srt/constrained/llguidance_backend.py CHANGED
@@ -14,49 +14,48 @@
 """Constrained decoding with llguidance backend."""
 
 import json
+import logging
 import os
 from typing import List, Optional, Tuple
 
-import llguidance
-import llguidance.hf
-import llguidance.torch
 import torch
-from llguidance.gbnf_to_lark import any_to_lark
+from llguidance import LLMatcher, LLTokenizer, StructTag, grammar_from
+from llguidance.hf import from_tokenizer
+from llguidance.torch import (
+    allocate_token_bitmask,
+    apply_token_bitmask_inplace,
+    fill_next_token_bitmask,
+)
 
 from sglang.srt.constrained.base_grammar_backend import (
     BaseGrammarBackend,
     BaseGrammarObject,
 )
 
+logger = logging.getLogger(__name__)
+
 
 class GuidanceGrammar(BaseGrammarObject):
-    def __init__(
-        self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
-    ):
+
+    def __init__(self, llguidance_tokenizer: LLTokenizer, serialized_grammar: str):
         super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar
 
-        # TODO: add support for fast-forward tokens in the future
-        self.ll_interpreter = llguidance.LLInterpreter(
+        self.ll_matcher = LLMatcher(
             self.llguidance_tokenizer,
             self.serialized_grammar,
-            enable_backtrack=False,
-            enable_ff_tokens=False,
             log_level=int(os.environ.get("LLGUIDANCE_LOG_LEVEL", "1")),
         )
-        self.pending_ff_tokens: list[int] = []
         self.finished = False
         self.bitmask = None
 
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
-        if len(self.pending_ff_tokens) > 0:
-            s = self.llguidance_tokenizer.decode_str(self.pending_ff_tokens)
-            ff_tokens = self.pending_ff_tokens
-            self.pending_ff_tokens = []
-            return (ff_tokens, s)
-
-        return None
+        ff_tokens = self.ll_matcher.compute_ff_tokens()
+        if ff_tokens:
+            return ff_tokens, ""
+        else:
+            return None
 
     def jump_forward_str_state(self, helper: Tuple[List[int], str]) -> Tuple[str, int]:
         return "", -1
@@ -67,32 +66,22 @@ class GuidanceGrammar(BaseGrammarObject):
         pass
 
     def accept_token(self, token: int):
-        backtrack, ff_tokens = self.ll_interpreter.commit_token(token)
-        if len(ff_tokens) > 0 and backtrack == 0:
-            # first token is last generated token
-            ff_tokens = ff_tokens[1:]
-        self.pending_ff_tokens.extend(ff_tokens)
+        if not self.ll_matcher.consume_token(token):
+            logger.warning(f"matcher error: {self.ll_matcher.get_error()}")
+            self.finished = True
 
     def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
-        if len(self.pending_ff_tokens) > 0:
-            # if we have pending fast-forward tokens,
-            # just return them immediately
-            ff_token = self.pending_ff_tokens.pop(0)
-            vocab_mask[idx, :] = 0
-            vocab_mask[idx, ff_token // 32] = 1 << (ff_token % 32)
-            return
-
-        if self.ll_interpreter.has_pending_stop():
+        if self.ll_matcher.is_stopped():
             self.finished = True
 
-        llguidance.torch.fill_next_token_bitmask(self.ll_interpreter, vocab_mask, idx)
+        fill_next_token_bitmask(self.ll_matcher, vocab_mask, idx)
 
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
     ) -> torch.Tensor:
         if self.bitmask is None or self.bitmask.shape[0] < batch_size:
             # only create bitmask when batch gets larger
-            self.bitmask = llguidance.torch.allocate_token_bitmask(
+            self.bitmask = allocate_token_bitmask(
                 batch_size, self.llguidance_tokenizer.vocab_size
             )
         bitmask = self.bitmask
@@ -107,7 +96,7 @@ class GuidanceGrammar(BaseGrammarObject):
 
     @staticmethod
     def apply_vocab_mask(logits: torch.Tensor, vocab_mask: torch.Tensor) -> None:
-        llguidance.torch.apply_token_bitmask_inplace(logits, vocab_mask)
+        apply_token_bitmask_inplace(logits, vocab_mask)
 
     def copy(self):
         return GuidanceGrammar(
@@ -117,36 +106,64 @@ class GuidanceGrammar(BaseGrammarObject):
 
 
 class GuidanceBackend(BaseGrammarBackend):
-    def __init__(self, tokenizer, whitespace_pattern: Optional[str] = None):
+
+    def __init__(
+        self,
+        tokenizer,
+        whitespace_pattern: Optional[str] = None,
+        n_vocab: Optional[int] = None,
+    ):
         super().__init__()
 
         self.tokenizer = tokenizer
-        self.whitespace_flexible = (
-            True if whitespace_pattern == "whitespace_flexible" else False
-        )
-        self.llguidance_tokenizer = llguidance.hf.from_tokenizer(self.tokenizer, None)
-
-    def _from_serialized(self, serialized_grammar) -> GuidanceGrammar:
-        return GuidanceGrammar(
-            llguidance_tokenizer=self.llguidance_tokenizer,
-            serialized_grammar=serialized_grammar,
+        self.whitespace_pattern = whitespace_pattern
+        self.llguidance_tokenizer = from_tokenizer(self.tokenizer, n_vocab)
+
+    def _from_serialized(self, serialized_grammar) -> Optional[GuidanceGrammar]:
+        try:
+            return GuidanceGrammar(
+                llguidance_tokenizer=self.llguidance_tokenizer,
+                serialized_grammar=serialized_grammar,
+            )
+        except Exception as e:
+            logger.warning(f"Skip invalid grammar: {serialized_grammar}, {e=}")
+            return None
+
+    def dispatch_json(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = LLMatcher.grammar_from_json_schema(
+            key_string,
+            defaults={
+                "whitespace_pattern": self.whitespace_pattern,
+            },
         )
-
-    def dispatch_json(self, key_string: str) -> GuidanceGrammar:
-        json_schema = key_string
-        compiler = llguidance.JsonCompiler(whitespace_flexible=self.whitespace_flexible)
-        serialized_grammar = compiler.compile(json_schema)
-        return self._from_serialized(serialized_grammar)
-
-    def dispatch_regex(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.RegexCompiler()
-        serialized_grammar = compiler.compile(regex=key_string)
         return self._from_serialized(serialized_grammar)
 
-    def dispatch_ebnf(self, key_string: str) -> GuidanceGrammar:
-        compiler = llguidance.LarkCompiler()
-        serialized_grammar = compiler.compile(any_to_lark(key_string))
+    def dispatch_regex(self, key_string: str) -> Optional[GuidanceGrammar]:
+        serialized_grammar = grammar_from("regex", key_string)
         return self._from_serialized(serialized_grammar)
 
-    def dispatch_structural_tag(self, key_string: str):
-        return super().dispatch_structural_tag(key_string)
+    def dispatch_ebnf(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            serialized_grammar = grammar_from("ebnf", key_string)
+            return self._from_serialized(serialized_grammar)
+        except ValueError as e:
+            logger.warning(f"Skip invalid ebnf: regex={key_string}, {e=}")
+            return None
+
+    def dispatch_structural_tag(self, key_string: str) -> Optional[GuidanceGrammar]:
+        try:
+            structural_tag = json.loads(key_string)
+            tags = [
+                StructTag(
+                    begin=structure["begin"],
+                    grammar=structure["schema"],
+                    end=structure["end"],
+                    trigger=structural_tag["triggers"][0],  # TODO?
+                )
+                for structure in structural_tag["structures"]
+            ]
+            g = StructTag.to_grammar(tags)
+            return self._from_serialized(g)
+        except Exception as e:
+            logging.warning(f"Skip invalid structural_tag: {key_string}, {e=}")
+            return None
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -158,6 +158,7 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
     def dispatch_json(self, key_string: str) -> Optional[XGrammarGrammar]:
         try:
             if key_string == "$$ANY$$":
+                # Note: This builtin JSON grammar includes *all* valid JSON (including, for example, arrays at the root)
                 ctx = self.grammar_compiler.compile_builtin_json_grammar()
             else:
                 ctx = self.grammar_compiler.compile_json_schema(schema=key_string)
sglang/srt/conversation.py CHANGED
@@ -463,6 +463,30 @@ def generate_embedding_convs(
     return convs
 
 
+# Models in which system adds modality tokens at prompt start automatically
+# when media inputs exceed modality tokens in prompt (e.g. 3 images but 2 <image> tokens)
+_MODELS_REQUIRING_MODALITY_SUPPLEMENT = {"deepseek-vl2"}
+
+
+# adapted from https://github.com/vllm-project/vllm/blob/5124f5bf51b83e6f344c1bc6652e8c4d81313b34/vllm/entrypoints/chat_utils.py#L856
+def _get_full_multimodal_text_prompt(
+    modality_token: str, modality_count: int, text_prompt: str
+) -> str:
+    """Combine multimodal prompts for a multimodal language model."""
+
+    # For any existing placeholder in the text prompt, we leave it as is
+    left: int = modality_count - text_prompt.count(modality_token)
+    if left < 0:
+        raise ValueError(
+            f"Found more '{modality_token}' placeholders in input prompt than "
+            "actual multimodal data items."
+        )
+
+    # NOTE: For now we always add missing modality_token at the front of
+    # the prompt. This may change to be customizable in the future.
+    return "\n".join([modality_token] * left + [text_prompt])
+
+
 def generate_chat_conv(
     request: ChatCompletionRequest, template_name: str
 ) -> Conversation:
@@ -520,6 +544,12 @@ def generate_chat_conv(
                         if conv.name != "qwen2-vl"
                         else conv.image_token
                     )
+                add_token_as_needed: bool = (
+                    conv.name in _MODELS_REQUIRING_MODALITY_SUPPLEMENT
+                )
+                if add_token_as_needed:
+                    image_token = ""
+
                 audio_token = conv.audio_token
                 for content in message.content:
                     if content.type == "text":
@@ -533,7 +563,10 @@ def generate_chat_conv(
                     elif content.type == "audio_url":
                         real_content += audio_token
                         conv.append_audio(content.audio_url.url)
-
+                if add_token_as_needed:
+                    real_content = _get_full_multimodal_text_prompt(
+                        conv.image_token, num_image_url, real_content
+                    )
                 conv.append_message(conv.roles[0], real_content)
         elif msg_role == "assistant":
             parsed_content = ""