sglang 0.4.4.post4__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (134)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/lang/chat_template.py +24 -0
  4. sglang/srt/configs/model_config.py +40 -4
  5. sglang/srt/constrained/base_grammar_backend.py +26 -5
  6. sglang/srt/constrained/llguidance_backend.py +1 -0
  7. sglang/srt/constrained/outlines_backend.py +1 -0
  8. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  9. sglang/srt/constrained/xgrammar_backend.py +1 -0
  10. sglang/srt/conversation.py +29 -4
  11. sglang/srt/disaggregation/base/__init__.py +8 -0
  12. sglang/srt/disaggregation/base/conn.py +113 -0
  13. sglang/srt/disaggregation/decode.py +18 -5
  14. sglang/srt/disaggregation/mini_lb.py +53 -122
  15. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  16. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  17. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  18. sglang/srt/disaggregation/prefill.py +43 -19
  19. sglang/srt/disaggregation/utils.py +31 -0
  20. sglang/srt/entrypoints/EngineBase.py +53 -0
  21. sglang/srt/entrypoints/engine.py +36 -8
  22. sglang/srt/entrypoints/http_server.py +37 -8
  23. sglang/srt/entrypoints/http_server_engine.py +142 -0
  24. sglang/srt/entrypoints/verl_engine.py +37 -10
  25. sglang/srt/hf_transformers_utils.py +4 -0
  26. sglang/srt/layers/attention/flashattention_backend.py +609 -202
  27. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  28. sglang/srt/layers/attention/vision.py +1 -1
  29. sglang/srt/layers/dp_attention.py +2 -4
  30. sglang/srt/layers/elementwise.py +15 -2
  31. sglang/srt/layers/linear.py +1 -0
  32. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  33. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  38. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  39. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  40. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  43. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +51 -24
  48. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  49. sglang/srt/layers/moe/router.py +7 -1
  50. sglang/srt/layers/moe/topk.py +37 -16
  51. sglang/srt/layers/quantization/__init__.py +13 -5
  52. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  53. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  54. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  55. sglang/srt/layers/quantization/fp8.py +28 -14
  56. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  57. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  58. sglang/srt/layers/quantization/kv_cache.py +43 -52
  59. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  62. sglang/srt/layers/quantization/w8a8_int8.py +3 -0
  63. sglang/srt/layers/radix_attention.py +14 -0
  64. sglang/srt/layers/rotary_embedding.py +75 -1
  65. sglang/srt/managers/io_struct.py +254 -97
  66. sglang/srt/managers/mm_utils.py +3 -2
  67. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  68. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  69. sglang/srt/managers/multimodal_processors/mllama4.py +146 -0
  70. sglang/srt/managers/schedule_batch.py +62 -21
  71. sglang/srt/managers/scheduler.py +71 -14
  72. sglang/srt/managers/tokenizer_manager.py +17 -3
  73. sglang/srt/managers/tp_worker.py +1 -0
  74. sglang/srt/mem_cache/memory_pool.py +14 -1
  75. sglang/srt/metrics/collector.py +9 -0
  76. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  77. sglang/srt/model_executor/forward_batch_info.py +234 -15
  78. sglang/srt/model_executor/model_runner.py +49 -9
  79. sglang/srt/model_loader/loader.py +31 -4
  80. sglang/srt/model_loader/weight_utils.py +4 -2
  81. sglang/srt/models/baichuan.py +2 -0
  82. sglang/srt/models/chatglm.py +1 -0
  83. sglang/srt/models/commandr.py +1 -0
  84. sglang/srt/models/dbrx.py +1 -0
  85. sglang/srt/models/deepseek.py +1 -0
  86. sglang/srt/models/deepseek_v2.py +248 -61
  87. sglang/srt/models/exaone.py +1 -0
  88. sglang/srt/models/gemma.py +1 -0
  89. sglang/srt/models/gemma2.py +1 -0
  90. sglang/srt/models/gemma3_causal.py +1 -0
  91. sglang/srt/models/gpt2.py +1 -0
  92. sglang/srt/models/gpt_bigcode.py +1 -0
  93. sglang/srt/models/granite.py +1 -0
  94. sglang/srt/models/grok.py +1 -0
  95. sglang/srt/models/internlm2.py +1 -0
  96. sglang/srt/models/llama.py +13 -4
  97. sglang/srt/models/llama4.py +487 -0
  98. sglang/srt/models/minicpm.py +1 -0
  99. sglang/srt/models/minicpm3.py +2 -0
  100. sglang/srt/models/mixtral.py +1 -0
  101. sglang/srt/models/mixtral_quant.py +1 -0
  102. sglang/srt/models/mllama.py +51 -8
  103. sglang/srt/models/mllama4.py +227 -0
  104. sglang/srt/models/olmo.py +1 -0
  105. sglang/srt/models/olmo2.py +1 -0
  106. sglang/srt/models/olmoe.py +1 -0
  107. sglang/srt/models/phi3_small.py +1 -0
  108. sglang/srt/models/qwen.py +1 -0
  109. sglang/srt/models/qwen2.py +1 -0
  110. sglang/srt/models/qwen2_5_vl.py +35 -70
  111. sglang/srt/models/qwen2_moe.py +1 -0
  112. sglang/srt/models/qwen2_vl.py +27 -25
  113. sglang/srt/models/stablelm.py +1 -0
  114. sglang/srt/models/xverse.py +1 -0
  115. sglang/srt/models/xverse_moe.py +1 -0
  116. sglang/srt/openai_api/adapter.py +4 -1
  117. sglang/srt/patch_torch.py +11 -0
  118. sglang/srt/server_args.py +34 -0
  119. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  120. sglang/srt/speculative/eagle_utils.py +1 -11
  121. sglang/srt/speculative/eagle_worker.py +6 -2
  122. sglang/srt/utils.py +120 -9
  123. sglang/test/attention/test_flashattn_backend.py +259 -221
  124. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  125. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  126. sglang/test/test_block_fp8.py +57 -0
  127. sglang/test/test_utils.py +19 -8
  128. sglang/version.py +1 -1
  129. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  130. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +133 -109
  131. sglang/srt/disaggregation/conn.py +0 -81
  132. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  133. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  134. {sglang-0.4.4.post4.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -135,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
@@ -184,6 +186,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return input_ids, reqs
@@ -199,6 +202,7 @@ def prepare_extend_inputs_for_correctness_test(
             i, : bench_args.cut_len
         ]
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
     return reqs
 
 
@@ -220,6 +224,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)
 
     return reqs
@@ -238,6 +243,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +255,7 @@ def extend(reqs, model_runner):
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +263,20 @@ def decode(input_token_ids, batch, model_runner):
     return next_token_ids, logits_output.next_token_logits
 
 
+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
sglang/bench_serving.py CHANGED
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
             prompt_suffix=args.prompt_suffix,
             apply_chat_template=args.apply_chat_template,
         )
-    elif args.dataset_name == "random":
+    elif args.dataset_name.startswith("random"):
         input_requests = sample_random_requests(
             input_len=args.random_input_len,
             output_len=args.random_output_len,
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
             range_ratio=args.random_range_ratio,
             tokenizer=tokenizer,
             dataset_path=args.dataset_path,
+            random_sample=args.dataset_name == "random",
         )
     elif args.dataset_name == "generated-shared-prefix":
         input_requests = sample_generated_shared_prefix_requests(
@@ -687,6 +688,7 @@ def sample_random_requests(
     range_ratio: float,
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
+    random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:
 
     input_lens = np.random.randint(
@@ -700,11 +702,15 @@ def sample_random_requests(
         size=num_prompts,
     )
 
-    if True:
+    if random_sample:
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens
 
         # Download sharegpt if necessary
         if not os.path.isfile(dataset_path):
+            print(
+                "If you do not want to randomly sample from a dataset,"
+                " please use --dataset-name random-ids."
+            )
             dataset_path = download_and_cache_file(SHAREGPT_URL)
 
         # Load the dataset.
@@ -1223,7 +1229,7 @@ async def benchmark(
         output_file_name = args.output_file
     else:
         now = datetime.now().strftime("%m%d")
-        if args.dataset_name == "random":
+        if args.dataset_name.startswith("random"):
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
         else:
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
@@ -1442,7 +1448,7 @@ if __name__ == "__main__":
         "--dataset-name",
         type=str,
         default="sharegpt",
-        choices=["sharegpt", "random", "generated-shared-prefix"],
+        choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
         help="Name of the dataset to benchmark on.",
     )
     parser.add_argument(
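Note (illustrative): with this change both "random" and "random-ids" reach sample_random_requests, and only the random_sample flag differs. A minimal sketch of the call that get_dataset now builds; every literal value below is invented, and args/tokenizer are assumed to come from the surrounding benchmark setup:

input_requests = sample_random_requests(
    input_len=1024,
    output_len=128,
    num_prompts=16,
    range_ratio=1.0,
    tokenizer=tokenizer,
    dataset_path="",                                # ShareGPT is downloaded on demand when sampling is requested
    random_sample=(args.dataset_name == "random"),  # False for --dataset-name random-ids
)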
sglang/lang/chat_template.py CHANGED
@@ -294,6 +294,30 @@ register_chat_template(
     )
 )
 
+# Reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_chat_template(
+    ChatTemplate(
+        name="llama-4",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": (
+                "<|header_start|>system<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "user": (
+                "<|header_start|>user<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+            "assistant": (
+                "<|header_start|>assistant<|header_end|>\n\n",
+                "<|eot|>",
+            ),
+        },
+        stop_str=("<|eot|>",),
+        image_token="<|image|>",
+    )
+)
+
 # Reference: https://modelscope.cn/models/01ai/Yi-1.5-34B-Chat/file/view/master?fileName=tokenizer_config.json&status=1
 register_chat_template(
     ChatTemplate(
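Note (illustrative): rendered with the prefixes and suffixes above, a single user turn takes the following shape; the question text is invented, the control tokens come from the template:

# Prompt shape produced by the llama-4 chat template for one user turn.
prompt = (
    "<|header_start|>user<|header_end|>\n\n"
    "What is the capital of France?<|eot|>"
    "<|header_start|>assistant<|header_end|>\n\n"
)
# Decoding is expected to stop at "<|eot|>", the template's stop_str.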
sglang/srt/configs/model_config.py CHANGED
@@ -15,6 +15,7 @@
 import json
 import logging
 import math
+import os
 from enum import IntEnum, auto
 from typing import List, Optional, Set, Union
 
@@ -42,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
+
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -65,15 +68,32 @@ class ModelConfig:
             **kwargs,
         )
         self.hf_text_config = get_hf_text_config(self.hf_config)
+        self.attention_chunk_size = getattr(
+            self.hf_text_config, "attention_chunk_size", None
+        )
+
+        if enable_multimodal is None:
+            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+                enable_multimodal = False
+            else:
+                enable_multimodal = True
 
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
-        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
-        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
-        self.is_audio_model = is_audio_model(self.hf_config.architectures)
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)
 
@@ -231,6 +251,20 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+        if quant_cfg is None:
+            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
+            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            is_local = os.path.exists(self.model_path)
+            modelopt_quant_config = {"quant_method": "modelopt"}
+            if not is_local:
+                from huggingface_hub import HfApi
+
+                hf_api = HfApi()
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    quant_cfg = modelopt_quant_config
+            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
+                quant_cfg = modelopt_quant_config
         return quant_cfg
 
     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -261,6 +295,7 @@ class ModelConfig:
             "moe_wna16",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
@@ -468,6 +503,7 @@ multimodal_model_archs = [
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -28,6 +28,18 @@ logger = logging.getLogger(__name__)
 
 
 class BaseGrammarObject(ABC):
+
+    def __init__(self):
+        self._finished = False
+
+    @property
+    def finished(self):
+        return self._finished
+
+    @finished.setter
+    def finished(self, finished):
+        self._finished = finished
+
     @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
@@ -59,6 +71,13 @@ class BaseGrammarObject(ABC):
         """
         raise NotImplementedError
 
+    @abstractmethod
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -90,7 +109,7 @@ class CacheEntry:
     event: Event
 
 
-class BaseGrammarBackend(ABC):
+class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
@@ -107,19 +126,15 @@ class BaseGrammarBackend(ABC):
         """
         raise ValueError(f"Invalid key_type: {key_type}={key_string}")
 
-    @abstractmethod
     def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("json", key_string)
 
-    @abstractmethod
     def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("regex", key_string)
 
-    @abstractmethod
     def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("ebnf", key_string)
 
-    @abstractmethod
     def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("structural_tag", key_string)
 
@@ -195,4 +210,10 @@ def create_grammar_backend(
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
 
+    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
+        from .reasoner_grammar_backend import ReasonerGrammarBackend
+
+        grammar_backend = ReasonerGrammarBackend(
+            grammar_backend, tokenizer.think_end_id
+        )
     return grammar_backend
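Note (illustrative): the super().__init__() calls added to the llguidance, outlines, and xgrammar grammars below exist because BaseGrammarObject now carries state (the finished flag) and an accept_token contract. A minimal sketch of what a conforming grammar looks like; the class is invented, and a real implementation must still provide the remaining interface methods (vocab-mask filling, jump-forward handling, copy, and so on):

import torch
from sglang.srt.constrained.base_grammar_backend import BaseGrammarObject

class PassthroughGrammar(BaseGrammarObject):
    def __init__(self):
        super().__init__()  # now required: initializes the shared `finished` flag

    def accept_token(self, token: int) -> None:
        pass  # nothing to advance; a real grammar feeds the token to its matcher

    def try_jump_forward(self, tokenizer):
        return None  # no jump-forward available

    def allocate_vocab_mask(self, vocab_size: int, batch_size: int, device) -> torch.Tensor:
        return torch.zeros(batch_size, vocab_size, dtype=torch.bool, device=device)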
sglang/srt/constrained/llguidance_backend.py CHANGED
@@ -33,6 +33,7 @@ class GuidanceGrammar(BaseGrammarObject):
     def __init__(
         self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
     ):
+        super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar
 
sglang/srt/constrained/outlines_backend.py CHANGED
@@ -44,6 +44,7 @@ class OutlinesGrammar(BaseGrammarObject):
         guide: RegexGuide,
         jump_forward_map: Union[OutlinesJumpForwardMap, None],
     ) -> None:
+        super().__init__()
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
sglang/srt/constrained/reasoner_grammar_backend.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
+
+from concurrent.futures import Future
+from typing import List, Optional, Tuple
+
+import torch
+
+from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
+
+
+class ReasonerGrammarObject(BaseGrammarObject):
+    def __init__(self, grammar: BaseGrammarObject, think_end_id):
+        super().__init__()
+        self.grammar = grammar
+        self.think_end_id = think_end_id
+        self.is_in_reasoning = True
+
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        if not self.is_in_reasoning:
+            self.grammar.fill_vocab_mask(vocab_mask, idx)
+
+    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return self.grammar.move_vocab_mask(vocab_mask, device)
+
+    @property
+    def apply_vocab_mask(self):
+        return self.grammar.apply_vocab_mask
+
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
+
+    def try_jump_forward(self, tokenizer):
+        return self.grammar.try_jump_forward(tokenizer)
+
+    def jump_forward_str_state(self, helper):
+        return self.grammar.jump_forward_str_state(helper)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        return self.grammar.jump_and_retokenize(
+            old_output_ids, new_output_ids, next_state
+        )
+
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+
+class ReasonerGrammarBackend(BaseGrammarBackend):
+    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        self.grammar_backend = grammar_backend
+        self.think_end_id = think_end_id
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+        grammar = self.grammar_backend.get_cached_value(key)
+        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        grammar = Future()
+
+        def callback(f: Future):
+            if result := f.result():
+                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
+            else:
+                grammar.set_result(None)
+
+        self.grammar_backend.get_future_value(key).add_done_callback(callback)
+        return grammar
+
+    def reset(self):
+        self.grammar_backend.reset()
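Note (illustrative): ReasonerGrammarObject holds the wrapped grammar back until the reasoning phase ends. Tokens produced before think_end_id are neither fed to the inner grammar nor constrained by it; the first think_end_id flips the gate. A sketch with invented objects and token ids:

gated = ReasonerGrammarObject(json_grammar, think_end_id=128003)  # inner grammar and id are placeholders

gated.accept_token(1234)        # still reasoning: not forwarded to json_grammar
gated.fill_vocab_mask(mask, 0)  # no-op while is_in_reasoning is True
gated.accept_token(128003)      # the think-end token closes the reasoning phase
gated.accept_token(5678)        # from here on, tokens are forwarded and masks are filled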
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -48,6 +48,7 @@ class XGrammarGrammar(BaseGrammarObject):
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
sglang/srt/conversation.py CHANGED
@@ -33,6 +33,7 @@ class SeparatorStyle(IntEnum):
     ADD_NEW_LINE_SINGLE = auto()
     LLAMA2 = auto()
     LLAMA3 = auto()
+    LLAMA4 = auto()
     CHATGLM = auto()
     CHATML = auto()
     CHATINTERN = auto()
@@ -156,19 +157,30 @@ class Conversation:
                 else:
                     ret += role + ":"
             return ret
+        elif self.sep_style == SeparatorStyle.LLAMA4:
+            # begin_of_text is added by default
+            if self.system_message:
+                ret = system_prompt
+            else:
+                ret = ""
+            for i, (role, message) in enumerate(self.messages):
+                if message:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+                    ret += f"{message.strip()}<|eot|>"
+                else:
+                    ret += f"<|header_start|>{role}<|header_end|>\n\n"
+            return ret
         elif self.sep_style == SeparatorStyle.LLAMA3:
-            ret = "<|begin_of_text|>"
             if self.system_message:
-                ret += system_prompt
+                ret = system_prompt
             else:
-                ret += ""
+                ret = ""
             for i, (role, message) in enumerate(self.messages):
                 if message:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
                     ret += f"{message.strip()}<|eot_id|>"
                 else:
                     ret += f"<|start_header_id|>{role}<|end_header_id|>\n\n"
-            # print(ret)
             return ret
         elif self.sep_style == SeparatorStyle.LLAMA2:
             seps = [self.sep, self.sep2]
@@ -561,6 +573,19 @@ register_conv_template(
     )
 )
 
+# reference: https://huggingface.co/meta-llama/Llama-4-Scout-17B-16E-Instruct/blob/main/chat_template.json
+register_conv_template(
+    Conversation(
+        name="llama-4",
+        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
+        roles=("user", "assistant"),
+        sep_style=SeparatorStyle.LLAMA4,
+        sep="",
+        stop_str=["<|end_of_text|>", "<|eot|>", "<|eom|>"],
+        image_token="<|image|>",
+    )
+)
+
 register_conv_template(
     Conversation(
         name="chatml",
sglang/srt/disaggregation/base/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .conn import (
+    BaseKVBootstrapServer,
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
sglang/srt/disaggregation/base/conn.py ADDED
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+
+
+class KVArgs:
+    engine_rank: int
+    kv_data_ptrs: list[int]
+    kv_data_lens: list[int]
+    kv_item_lens: list[int]
+    aux_data_ptrs: list[int]
+    aux_data_lens: list[int]
+    aux_item_lens: list[int]
+    ib_device: str
+    gpu_id: int
+
+
+class KVPoll:
+    Failed = 0
+    Bootstrapping = 1
+    WaitingForInput = 2
+    Transferring = 3
+    Success = 4
+
+
+class BaseKVManager(ABC):
+    """Base class for managing transfers states"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        args: KVArgs,
+        disaggregation_mode: DisaggregationMode,
+        server_args: ServerArgs,
+    ): ...
+
+
+class BaseKVSender(ABC):
+
+    @abstractmethod
+    def __init__(
+        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+    ): ...
+
+    @abstractmethod
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        """
+        Notify the decoder server about the kv indices length and aux index
+        """
+        ...
+
+    @abstractmethod
+    def send(self, kv_indices: npt.NDArray[np.int64]):
+        """
+        Send the kv cache at the given kv indices to the decoder server
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVReceiver(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ): ...
+
+    @abstractmethod
+    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+        """
+        Notify the prefill server about the kv indices and aux index
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVBootstrapServer(ABC):
+    @abstractmethod
+    def __init__(self, port: int): ...
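Note (illustrative): these ABCs define the lifecycle that a concrete transport (for example the Mooncake connector added in this release) implements on the prefill side: construct a sender, announce the transfer, push the KV pages, then poll until a terminal state. A hedged sketch; the sender class, manager, and indices are placeholders, and only the method names and KVPoll states come from this file:

import numpy as np
from sglang.srt.disaggregation.base import KVPoll

sender = SomeKVSender(mgr, bootstrap_addr="10.0.0.1:8998", bootstrap_room=0)  # hypothetical subclass
sender.init(num_kv_indices=len(kv_indices), aux_index=0)  # tell the decode side what is coming
sender.send(np.asarray(kv_indices, dtype=np.int64))       # push the KV cache at those indices

while (status := sender.poll()) in (KVPoll.Bootstrapping, KVPoll.WaitingForInput, KVPoll.Transferring):
    pass  # busy-wait for brevity; real callers yield back to the scheduler loop
if status == KVPoll.Failed:
    sender.failure_exception()  # raises the transfer error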