sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json} +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/__init__.py CHANGED
@@ -24,6 +24,7 @@ from sglang.api import (
  user_end,
  video,
  )
+ from sglang.global_config import global_config
  from sglang.lang.backend.runtime_endpoint import RuntimeEndpoint
  from sglang.lang.choices import (
  greedy_token_selection,
@@ -31,6 +32,7 @@ from sglang.lang.choices import (
  unconditional_likelihood_normalized,
  )
  from sglang.utils import LazyImport
+ from sglang.version import __version__

  ServerArgs = LazyImport("sglang.srt.server_args", "ServerArgs")
  Anthropic = LazyImport("sglang.lang.backend.anthropic", "Anthropic")
@@ -38,10 +40,6 @@ LiteLLM = LazyImport("sglang.lang.backend.litellm", "LiteLLM")
  OpenAI = LazyImport("sglang.lang.backend.openai", "OpenAI")
  VertexAI = LazyImport("sglang.lang.backend.vertexai", "VertexAI")

- # Other configs
- from sglang.global_config import global_config
- from sglang.version import __version__
-
  __all__ = [
  "Engine",
  "Runtime",
sglang/bench_one_batch.py CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
  from sglang.srt.entrypoints.engine import _set_envs_and_config
  from sglang.srt.hf_transformers_utils import get_tokenizer
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+ from sglang.srt.managers.scheduler import Scheduler
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.sampling.sampling_params import SamplingParams
@@ -135,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
  context_length=server_args.context_length,
  model_override_args=server_args.json_model_override_args,
  is_embedding=server_args.is_embedding,
+ enable_multimodal=server_args.enable_multimodal,
  dtype=server_args.dtype,
  quantization=server_args.quantization,
  )
@@ -184,6 +186,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
  req.prefix_indices = []
  req.fill_ids = req.origin_input_ids
  req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+ req.logprob_start_len = len(req.origin_input_ids) - 1
  reqs.append(req)

  return input_ids, reqs
@@ -199,11 +202,12 @@ def prepare_extend_inputs_for_correctness_test(
  i, : bench_args.cut_len
  ]
  req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+ req.logprob_start_len = len(req.origin_input_ids) - 1
  return reqs


  def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
- input_ids = np.ones((batch_size, input_len), dtype=np.int32)
+ input_ids = np.random.randint(0, 10000, (batch_size, input_len), dtype=np.int32)
  sampling_params = SamplingParams(
  temperature=0,
  max_new_tokens=BenchArgs.output_len,
@@ -220,6 +224,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
  req.prefix_indices = []
  req.fill_ids = req.origin_input_ids
  req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+ req.logprob_start_len = len(req.origin_input_ids) - 1
  reqs.append(req)

  return reqs
@@ -238,6 +243,7 @@ def extend(reqs, model_runner):
  enable_custom_logit_processor=False,
  )
  batch.prepare_for_extend()
+ _maybe_prepare_dp_attn_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
  logits_output = model_runner.forward(forward_batch)
@@ -249,6 +255,7 @@ def decode(input_token_ids, batch, model_runner):
  def decode(input_token_ids, batch, model_runner):
  batch.output_ids = input_token_ids
  batch.prepare_for_decode()
+ _maybe_prepare_dp_attn_batch(batch, model_runner)
  model_worker_batch = batch.get_model_worker_batch()
  forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
  logits_output = model_runner.forward(forward_batch)
@@ -256,6 +263,20 @@ def decode(input_token_ids, batch, model_runner):
  return next_token_ids, logits_output.next_token_logits


+ def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+ if model_runner.server_args.enable_dp_attention:
+ Scheduler.prepare_dp_attn_batch_raw(
+ batch,
+ dp_size=model_runner.server_args.dp_size,
+ attn_tp_size=1,
+ tp_cpu_group=model_runner.tp_group.cpu_group,
+ get_idle_batch=None,
+ disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+ spec_algorithm=SpeculativeAlgorithm.NONE,
+ speculative_num_draft_tokens=None,
+ )
+
+
  def correctness_test(
  server_args,
  port_args,
@@ -375,7 +396,7 @@ def latency_test_run_once(
  decode_latencies.append(latency)
  if i < 5:
  rank_print(
- f"Decode. latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
+ f"Decode. Batch size: {batch_size}, latency: {latency:6.5f} s, throughput: {throughput:9.2f} token/s"
  )

sglang/bench_serving.py CHANGED
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
  prompt_suffix=args.prompt_suffix,
  apply_chat_template=args.apply_chat_template,
  )
- elif args.dataset_name == "random":
+ elif args.dataset_name.startswith("random"):
  input_requests = sample_random_requests(
  input_len=args.random_input_len,
  output_len=args.random_output_len,
@@ -498,6 +498,7 @@ def get_dataset(args, tokenizer):
  range_ratio=args.random_range_ratio,
  tokenizer=tokenizer,
  dataset_path=args.dataset_path,
+ random_sample=args.dataset_name == "random",
  )
  elif args.dataset_name == "generated-shared-prefix":
  input_requests = sample_generated_shared_prefix_requests(
@@ -687,6 +688,7 @@ def sample_random_requests(
  range_ratio: float,
  tokenizer: PreTrainedTokenizerBase,
  dataset_path: str,
+ random_sample: bool = True,
  ) -> List[Tuple[str, int, int]]:

  input_lens = np.random.randint(
@@ -700,7 +702,7 @@ def sample_random_requests(
  size=num_prompts,
  )

- if True:
+ if random_sample:
  # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

  # Download sharegpt if necessary
@@ -1223,7 +1225,7 @@ async def benchmark(
  output_file_name = args.output_file
  else:
  now = datetime.now().strftime("%m%d")
- if args.dataset_name == "random":
+ if args.dataset_name.startswith("random"):
  output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
  else:
  output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
@@ -1442,7 +1444,7 @@ if __name__ == "__main__":
  "--dataset-name",
  type=str,
  default="sharegpt",
- choices=["sharegpt", "random", "generated-shared-prefix"],
+ choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
  help="Name of the dataset to benchmark on.",
  )
  parser.add_argument(
sglang/lang/backend/anthropic.py CHANGED
@@ -1,7 +1,3 @@
- from typing import List, Optional, Union
-
- import numpy as np
-
  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.chat_template import get_chat_template
  from sglang.lang.interpreter import StreamExecutor
sglang/lang/backend/base_backend.py CHANGED
@@ -1,4 +1,4 @@
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional, Union

  from sglang.lang.chat_template import get_chat_template
  from sglang.lang.choices import ChoicesDecision, ChoicesSamplingMethod
sglang/lang/backend/openai.py CHANGED
@@ -2,7 +2,7 @@ import dataclasses
  import logging
  import time
  import warnings
- from typing import Callable, List, Optional, Union
+ from typing import List, Optional, Union

  import numpy as np

sglang/lang/backend/vertexai.py CHANGED
@@ -1,6 +1,5 @@
  import os
  import warnings
- from typing import Optional

  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.chat_template import get_chat_template
sglang/lang/compiler.py CHANGED
@@ -5,13 +5,7 @@ from typing import List, Union

  from sglang.global_config import global_config
  from sglang.lang.interpreter import ProgramState, StreamExecutor, cache_program
- from sglang.lang.ir import (
- SglArgument,
- SglConstantText,
- SglExpr,
- SglSamplingParams,
- SglVariable,
- )
+ from sglang.lang.ir import SglArgument, SglExpr, SglSamplingParams, SglVariable


  def compile_func(function, backend):
sglang/lang/tracer.py CHANGED
@@ -1,20 +1,16 @@
  """Tracing a program."""

  import uuid
- from typing import Any, Callable, Dict, List, Optional, Union
+ from typing import Any, Dict, List, Optional

- from sglang.global_config import global_config
  from sglang.lang.backend.base_backend import BaseBackend
  from sglang.lang.interpreter import ProgramState, ProgramStateGroup
  from sglang.lang.ir import (
  SglArgument,
- SglCommitLazy,
- SglConcateAndAppend,
  SglConstantText,
  SglExpr,
  SglExprList,
  SglFork,
- SglFunction,
  SglGen,
  SglGetForkItem,
  SglRoleBegin,
@@ -230,8 +226,8 @@ class TracerProgramState(ProgramState):
  self.cur_role = None

  def _execute_var_scope_end(self, expr: SglVarScopeEnd):
- new_node = SglVariable(name, source=self.last_node)
- self.variables[name] = new_node
+ new_node = SglVariable(expr.name, source=self.last_node)
+ self.variables[expr.name] = new_node

  def get_var(self, name):
  ret = self.arguments.get(name, None)
sglang/srt/_custom_ops.py CHANGED
@@ -1,10 +1,8 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
  import logging
- import os
  from typing import List, Tuple

  import torch
- import torch.library

  from sglang.srt.utils import get_bool_env_var, is_hip, is_hpu

sglang/srt/configs/model_config.py CHANGED
@@ -15,6 +15,7 @@
  import json
  import logging
  import math
+ import os
  from enum import IntEnum, auto
  from typing import List, Optional, Set, Union

@@ -42,10 +43,12 @@ class ModelConfig:
  context_length: Optional[int] = None,
  model_override_args: Optional[str] = None,
  is_embedding: Optional[bool] = None,
+ enable_multimodal: Optional[bool] = None,
  dtype: str = "auto",
  quantization: Optional[str] = None,
  override_config_file: Optional[str] = None,
  ) -> None:
+
  self.model_path = model_path
  self.revision = revision
  self.quantization = quantization
@@ -69,14 +72,28 @@ class ModelConfig:
  self.hf_text_config, "attention_chunk_size", None
  )

+ if enable_multimodal is None:
+ if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+ enable_multimodal = False
+ else:
+ enable_multimodal = True
+
  # Check model type
  self.is_generation = is_generation_model(
  self.hf_config.architectures, is_embedding
  )
- self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
- self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
- self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
- self.is_audio_model = is_audio_model(self.hf_config.architectures)
+ self.is_multimodal = enable_multimodal and is_multimodal_model(
+ self.hf_config.architectures
+ )
+ self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+ self.hf_config.architectures
+ )
+ self.is_image_gen = enable_multimodal and is_image_gen_model(
+ self.hf_config.architectures
+ )
+ self.is_audio_model = enable_multimodal and is_audio_model(
+ self.hf_config.architectures
+ )
  self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
  self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

@@ -234,6 +251,20 @@ class ModelConfig:
  if quant_cfg is None:
  # compressed-tensors uses a "compression_config" key
  quant_cfg = getattr(self.hf_config, "compression_config", None)
+ if quant_cfg is None:
+ # check if is modelopt model -- modelopt doesn't have corresponding field
+ # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
+ # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+ is_local = os.path.exists(self.model_path)
+ modelopt_quant_config = {"quant_method": "modelopt"}
+ if not is_local:
+ from huggingface_hub import HfApi
+
+ hf_api = HfApi()
+ if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+ quant_cfg = modelopt_quant_config
+ elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
+ quant_cfg = modelopt_quant_config
  return quant_cfg

  # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -264,6 +295,7 @@ class ModelConfig:
  "moe_wna16",
  ]
  compatible_quantization_methods = {
+ "modelopt_fp4": ["modelopt"],
  "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
  "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
  }
@@ -470,8 +502,8 @@ multimodal_model_archs = [
  "Gemma3ForConditionalGeneration",
  "Grok1VForCausalLM",
  "Grok1AForCausalLM",
- # TODO: add multimodal support for "Llama4ForConditionalGeneration",
  "LlavaLlamaForCausalLM",
+ "Llama4ForConditionalGeneration",
  "LlavaMistralForCausalLM",
  "LlavaQwenForCausalLM",
  "LlavaVidForCausalLM",
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -28,6 +28,18 @@ logger = logging.getLogger(__name__)


  class BaseGrammarObject(ABC):
+
+ def __init__(self):
+ self._finished = False
+
+ @property
+ def finished(self):
+ return self._finished
+
+ @finished.setter
+ def finished(self, finished):
+ self._finished = finished
+
  @abstractmethod
  def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
  """
@@ -59,6 +71,13 @@ class BaseGrammarObject(ABC):
  """
  raise NotImplementedError

+ @abstractmethod
+ def accept_token(self, token: int) -> None:
+ """
+ Accept a token in the grammar.
+ """
+ raise NotImplementedError
+
  @abstractmethod
  def allocate_vocab_mask(
  self, vocab_size: int, batch_size: int, device
@@ -90,7 +109,7 @@ class CacheEntry:
  event: Event


- class BaseGrammarBackend(ABC):
+ class BaseGrammarBackend:
  def __init__(self):
  self.executor = ThreadPoolExecutor()
  self.cache: Dict[Tuple[str, str], CacheEntry] = {}
@@ -107,19 +126,15 @@ class BaseGrammarBackend(ABC):
  """
  raise ValueError(f"Invalid key_type: {key_type}={key_string}")

- @abstractmethod
  def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
  return self._not_supported("json", key_string)

- @abstractmethod
  def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
  return self._not_supported("regex", key_string)

- @abstractmethod
  def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
  return self._not_supported("ebnf", key_string)

- @abstractmethod
  def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
  return self._not_supported("structural_tag", key_string)

@@ -195,4 +210,10 @@ def create_grammar_backend(
  else:
  raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

+ if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
+ from .reasoner_grammar_backend import ReasonerGrammarBackend
+
+ grammar_backend = ReasonerGrammarBackend(
+ grammar_backend, tokenizer.think_end_id
+ )
  return grammar_backend
sglang/srt/constrained/llguidance_backend.py CHANGED
@@ -33,6 +33,7 @@ class GuidanceGrammar(BaseGrammarObject):
  def __init__(
  self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
  ):
+ super().__init__()
  self.llguidance_tokenizer = llguidance_tokenizer
  self.serialized_grammar = serialized_grammar

sglang/srt/constrained/outlines_backend.py CHANGED
@@ -44,6 +44,7 @@ class OutlinesGrammar(BaseGrammarObject):
  guide: RegexGuide,
  jump_forward_map: Union[OutlinesJumpForwardMap, None],
  ) -> None:
+ super().__init__()
  self.guide = guide
  self.jump_forward_map = jump_forward_map
  self.state = 0
sglang/srt/constrained/outlines_jump_forward.py CHANGED
@@ -19,10 +19,13 @@ Reference: https://lmsys.org/blog/2024-02-05-compressed-fsm/
  import dataclasses
  import logging
  from collections import defaultdict
+ from typing import Optional

  import interegular
  from interegular import InvalidSyntax
- from outlines.caching import cache as disk_cache
+ from outlines.caching import cache
+
+ from sglang.srt.utils import get_bool_env_var

  try:
  # outlines >= 0.1.0
@@ -34,6 +37,9 @@ except ImportError:

  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"

+ # Env var was set in sglang.srt.server_args.ServerArgs.__post__init__
+ DISABLE_DISK_CACHE = get_bool_env_var("SGLANG_DISABLE_OUTLINES_DISK_CACHE", "true")
+
  logger = logging.getLogger(__name__)


@@ -45,6 +51,13 @@ class JumpEdge:
  byte_next_state: int = None


+ def disk_cache(expire: Optional[float] = None, typed=False, ignore=()):
+ if not DISABLE_DISK_CACHE:
+ return cache(expire, typed, ignore)
+ else:
+ return lambda fn: None
+
+
  @disk_cache()
  def init_state_to_jump_forward(regex_string):
  try:
sglang/srt/constrained/reasoner_grammar_backend.py ADDED
@@ -0,0 +1,101 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """The baseclass of a backend for reasoner grammar-guided constrained decoding."""
+
+ from concurrent.futures import Future
+ from typing import List, Optional, Tuple
+
+ import torch
+
+ from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
+
+
+ class ReasonerGrammarObject(BaseGrammarObject):
+ def __init__(self, grammar: BaseGrammarObject, think_end_id):
+ super().__init__()
+ self.grammar = grammar
+ self.think_end_id = think_end_id
+ self.is_in_reasoning = True
+
+ @property
+ def finished(self):
+ return self.grammar.finished
+
+ @finished.setter
+ def finished(self, finished):
+ self.grammar.finished = finished
+
+ def allocate_vocab_mask(
+ self, vocab_size: int, batch_size: int, device
+ ) -> torch.Tensor:
+ return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
+
+ def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+ if not self.is_in_reasoning:
+ self.grammar.fill_vocab_mask(vocab_mask, idx)
+
+ def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
+ return self.grammar.move_vocab_mask(vocab_mask, device)
+
+ @property
+ def apply_vocab_mask(self):
+ return self.grammar.apply_vocab_mask
+
+ def accept_token(self, token: int):
+ if token == self.think_end_id:
+ self.is_in_reasoning = False
+
+ if not self.is_in_reasoning and token != self.think_end_id:
+ self.grammar.accept_token(token)
+
+ def try_jump_forward(self, tokenizer):
+ return self.grammar.try_jump_forward(tokenizer)
+
+ def jump_forward_str_state(self, helper):
+ return self.grammar.jump_forward_str_state(helper)
+
+ def jump_and_retokenize(
+ self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+ ):
+ return self.grammar.jump_and_retokenize(
+ old_output_ids, new_output_ids, next_state
+ )
+
+ def copy(self) -> BaseGrammarObject:
+ return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+
+ class ReasonerGrammarBackend(BaseGrammarBackend):
+ def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+ self.grammar_backend = grammar_backend
+ self.think_end_id = think_end_id
+
+ def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+ grammar = self.grammar_backend.get_cached_value(key)
+ return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
+
+ def get_future_value(self, key: Tuple[str, str]) -> Future:
+ grammar = Future()
+
+ def callback(f: Future):
+ if result := f.result():
+ grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
+ else:
+ grammar.set_result(None)
+
+ self.grammar_backend.get_future_value(key).add_done_callback(callback)
+ return grammar
+
+ def reset(self):
+ self.grammar_backend.reset()