sglang 0.4.5__py3-none-any.whl → 0.4.5.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (121)
  1. sglang/bench_one_batch.py +21 -0
  2. sglang/bench_serving.py +10 -4
  3. sglang/srt/configs/model_config.py +37 -5
  4. sglang/srt/constrained/base_grammar_backend.py +26 -5
  5. sglang/srt/constrained/llguidance_backend.py +1 -0
  6. sglang/srt/constrained/outlines_backend.py +1 -0
  7. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  8. sglang/srt/constrained/xgrammar_backend.py +1 -0
  9. sglang/srt/disaggregation/base/__init__.py +8 -0
  10. sglang/srt/disaggregation/base/conn.py +113 -0
  11. sglang/srt/disaggregation/decode.py +18 -5
  12. sglang/srt/disaggregation/mini_lb.py +53 -122
  13. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  14. sglang/srt/disaggregation/mooncake/conn.py +615 -0
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +108 -0
  16. sglang/srt/disaggregation/prefill.py +43 -19
  17. sglang/srt/disaggregation/utils.py +31 -0
  18. sglang/srt/entrypoints/EngineBase.py +53 -0
  19. sglang/srt/entrypoints/engine.py +36 -8
  20. sglang/srt/entrypoints/http_server.py +37 -8
  21. sglang/srt/entrypoints/http_server_engine.py +142 -0
  22. sglang/srt/entrypoints/verl_engine.py +37 -10
  23. sglang/srt/hf_transformers_utils.py +4 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +330 -200
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -7
  26. sglang/srt/layers/attention/vision.py +1 -1
  27. sglang/srt/layers/dp_attention.py +2 -4
  28. sglang/srt/layers/elementwise.py +15 -2
  29. sglang/srt/layers/linear.py +1 -0
  30. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  31. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  32. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  34. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +38 -21
  38. sglang/srt/layers/moe/router.py +7 -1
  39. sglang/srt/layers/moe/topk.py +37 -16
  40. sglang/srt/layers/quantization/__init__.py +12 -5
  41. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +4 -0
  42. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +68 -45
  43. sglang/srt/layers/quantization/fp8.py +25 -13
  44. sglang/srt/layers/quantization/fp8_kernel.py +130 -4
  45. sglang/srt/layers/quantization/fp8_utils.py +34 -6
  46. sglang/srt/layers/quantization/kv_cache.py +43 -52
  47. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  48. sglang/srt/layers/quantization/w8a8_fp8.py +154 -4
  49. sglang/srt/layers/quantization/w8a8_int8.py +1 -0
  50. sglang/srt/layers/radix_attention.py +13 -1
  51. sglang/srt/layers/rotary_embedding.py +12 -1
  52. sglang/srt/managers/io_struct.py +254 -97
  53. sglang/srt/managers/mm_utils.py +3 -2
  54. sglang/srt/managers/multimodal_processors/base_processor.py +114 -77
  55. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  56. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  57. sglang/srt/managers/schedule_batch.py +62 -21
  58. sglang/srt/managers/scheduler.py +71 -14
  59. sglang/srt/managers/tokenizer_manager.py +17 -3
  60. sglang/srt/managers/tp_worker.py +1 -0
  61. sglang/srt/mem_cache/memory_pool.py +14 -1
  62. sglang/srt/metrics/collector.py +9 -0
  63. sglang/srt/model_executor/cuda_graph_runner.py +7 -4
  64. sglang/srt/model_executor/forward_batch_info.py +234 -15
  65. sglang/srt/model_executor/model_runner.py +48 -9
  66. sglang/srt/model_loader/loader.py +31 -4
  67. sglang/srt/model_loader/weight_utils.py +4 -2
  68. sglang/srt/models/baichuan.py +2 -0
  69. sglang/srt/models/chatglm.py +1 -0
  70. sglang/srt/models/commandr.py +1 -0
  71. sglang/srt/models/dbrx.py +1 -0
  72. sglang/srt/models/deepseek.py +1 -0
  73. sglang/srt/models/deepseek_v2.py +248 -61
  74. sglang/srt/models/exaone.py +1 -0
  75. sglang/srt/models/gemma.py +1 -0
  76. sglang/srt/models/gemma2.py +1 -0
  77. sglang/srt/models/gemma3_causal.py +1 -0
  78. sglang/srt/models/gpt2.py +1 -0
  79. sglang/srt/models/gpt_bigcode.py +1 -0
  80. sglang/srt/models/granite.py +1 -0
  81. sglang/srt/models/grok.py +1 -0
  82. sglang/srt/models/internlm2.py +1 -0
  83. sglang/srt/models/llama.py +1 -0
  84. sglang/srt/models/llama4.py +101 -34
  85. sglang/srt/models/minicpm.py +1 -0
  86. sglang/srt/models/minicpm3.py +2 -0
  87. sglang/srt/models/mixtral.py +1 -0
  88. sglang/srt/models/mixtral_quant.py +1 -0
  89. sglang/srt/models/mllama.py +51 -8
  90. sglang/srt/models/mllama4.py +102 -29
  91. sglang/srt/models/olmo.py +1 -0
  92. sglang/srt/models/olmo2.py +1 -0
  93. sglang/srt/models/olmoe.py +1 -0
  94. sglang/srt/models/phi3_small.py +1 -0
  95. sglang/srt/models/qwen.py +1 -0
  96. sglang/srt/models/qwen2.py +1 -0
  97. sglang/srt/models/qwen2_5_vl.py +35 -70
  98. sglang/srt/models/qwen2_moe.py +1 -0
  99. sglang/srt/models/qwen2_vl.py +27 -25
  100. sglang/srt/models/stablelm.py +1 -0
  101. sglang/srt/models/xverse.py +1 -0
  102. sglang/srt/models/xverse_moe.py +1 -0
  103. sglang/srt/openai_api/adapter.py +4 -1
  104. sglang/srt/patch_torch.py +11 -0
  105. sglang/srt/server_args.py +34 -0
  106. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  107. sglang/srt/speculative/eagle_utils.py +1 -11
  108. sglang/srt/speculative/eagle_worker.py +6 -2
  109. sglang/srt/utils.py +120 -9
  110. sglang/test/attention/test_flashattn_backend.py +259 -221
  111. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  112. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  113. sglang/test/test_block_fp8.py +57 -0
  114. sglang/test/test_utils.py +19 -8
  115. sglang/version.py +1 -1
  116. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/METADATA +14 -4
  117. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/RECORD +120 -106
  118. sglang/srt/disaggregation/conn.py +0 -81
  119. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/WHEEL +0 -0
  120. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/licenses/LICENSE +0 -0
  121. {sglang-0.4.5.dist-info → sglang-0.4.5.post1.dist-info}/top_level.txt +0 -0
sglang/bench_one_batch.py CHANGED
@@ -60,6 +60,7 @@ from sglang.srt.configs.model_config import ModelConfig
 from sglang.srt.entrypoints.engine import _set_envs_and_config
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.managers.scheduler import Scheduler
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.sampling.sampling_params import SamplingParams
@@ -135,6 +136,7 @@ def load_model(server_args, port_args, tp_rank):
         context_length=server_args.context_length,
         model_override_args=server_args.json_model_override_args,
         is_embedding=server_args.is_embedding,
+        enable_multimodal=server_args.enable_multimodal,
         dtype=server_args.dtype,
         quantization=server_args.quantization,
     )
@@ -184,6 +186,7 @@ def prepare_inputs_for_correctness_test(bench_args, tokenizer):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)

     return input_ids, reqs
@@ -199,6 +202,7 @@ def prepare_extend_inputs_for_correctness_test(
            i, : bench_args.cut_len
        ]
        req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+       req.logprob_start_len = len(req.origin_input_ids) - 1
    return reqs


@@ -220,6 +224,7 @@ def prepare_synthetic_inputs_for_latency_test(batch_size, input_len):
         req.prefix_indices = []
         req.fill_ids = req.origin_input_ids
         req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
+        req.logprob_start_len = len(req.origin_input_ids) - 1
         reqs.append(req)

     return reqs
@@ -238,6 +243,7 @@ def extend(reqs, model_runner):
         enable_custom_logit_processor=False,
     )
     batch.prepare_for_extend()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -249,6 +255,7 @@
 def decode(input_token_ids, batch, model_runner):
     batch.output_ids = input_token_ids
     batch.prepare_for_decode()
+    _maybe_prepare_dp_attn_batch(batch, model_runner)
     model_worker_batch = batch.get_model_worker_batch()
     forward_batch = ForwardBatch.init_new(model_worker_batch, model_runner)
     logits_output = model_runner.forward(forward_batch)
@@ -256,6 +263,20 @@
     return next_token_ids, logits_output.next_token_logits


+def _maybe_prepare_dp_attn_batch(batch: ScheduleBatch, model_runner):
+    if model_runner.server_args.enable_dp_attention:
+        Scheduler.prepare_dp_attn_batch_raw(
+            batch,
+            dp_size=model_runner.server_args.dp_size,
+            attn_tp_size=1,
+            tp_cpu_group=model_runner.tp_group.cpu_group,
+            get_idle_batch=None,
+            disable_cuda_graph=model_runner.server_args.disable_cuda_graph,
+            spec_algorithm=SpeculativeAlgorithm.NONE,
+            speculative_num_draft_tokens=None,
+        )
+
+
 def correctness_test(
     server_args,
     port_args,
sglang/bench_serving.py CHANGED
@@ -490,7 +490,7 @@ def get_dataset(args, tokenizer):
            prompt_suffix=args.prompt_suffix,
            apply_chat_template=args.apply_chat_template,
        )
-    elif args.dataset_name == "random":
+    elif args.dataset_name.startswith("random"):
        input_requests = sample_random_requests(
            input_len=args.random_input_len,
            output_len=args.random_output_len,
@@ -498,6 +498,7 @@
            range_ratio=args.random_range_ratio,
            tokenizer=tokenizer,
            dataset_path=args.dataset_path,
+           random_sample=args.dataset_name == "random",
        )
    elif args.dataset_name == "generated-shared-prefix":
        input_requests = sample_generated_shared_prefix_requests(
@@ -687,6 +688,7 @@
    range_ratio: float,
    tokenizer: PreTrainedTokenizerBase,
    dataset_path: str,
+   random_sample: bool = True,
 ) -> List[Tuple[str, int, int]]:

    input_lens = np.random.randint(
@@ -700,11 +702,15 @@
        size=num_prompts,
    )

-   if True:
+   if random_sample:
        # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

        # Download sharegpt if necessary
        if not os.path.isfile(dataset_path):
+           print(
+               "If you do not want to randomly sample from a dataset,"
+               " please use --dataset-name random-ids."
+           )
            dataset_path = download_and_cache_file(SHAREGPT_URL)

        # Load the dataset.
@@ -1223,7 +1229,7 @@
        output_file_name = args.output_file
    else:
        now = datetime.now().strftime("%m%d")
-       if args.dataset_name == "random":
+       if args.dataset_name.startswith("random"):
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_{args.random_input_len}_{args.random_output_len}.jsonl"
        else:
            output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"
@@ -1442,7 +1448,7 @@
        "--dataset-name",
        type=str,
        default="sharegpt",
-       choices=["sharegpt", "random", "generated-shared-prefix"],
+       choices=["sharegpt", "random", "random-ids", "generated-shared-prefix"],
        help="Name of the dataset to benchmark on.",
    )
    parser.add_argument(
sglang/srt/configs/model_config.py CHANGED
@@ -15,6 +15,7 @@
 import json
 import logging
 import math
+import os
 from enum import IntEnum, auto
 from typing import List, Optional, Set, Union

@@ -42,10 +43,12 @@ class ModelConfig:
         context_length: Optional[int] = None,
         model_override_args: Optional[str] = None,
         is_embedding: Optional[bool] = None,
+        enable_multimodal: Optional[bool] = None,
         dtype: str = "auto",
         quantization: Optional[str] = None,
         override_config_file: Optional[str] = None,
     ) -> None:
+
         self.model_path = model_path
         self.revision = revision
         self.quantization = quantization
@@ -69,14 +72,28 @@ class ModelConfig:
             self.hf_text_config, "attention_chunk_size", None
         )

+        if enable_multimodal is None:
+            if self.hf_config.architectures == "Llama4ForConditionalGeneration":
+                enable_multimodal = False
+            else:
+                enable_multimodal = True
+
         # Check model type
         self.is_generation = is_generation_model(
             self.hf_config.architectures, is_embedding
         )
-        self.is_multimodal = is_multimodal_model(self.hf_config.architectures)
-        self.is_multimodal_gen = is_multimodal_gen_model(self.hf_config.architectures)
-        self.is_image_gen = is_image_gen_model(self.hf_config.architectures)
-        self.is_audio_model = is_audio_model(self.hf_config.architectures)
+        self.is_multimodal = enable_multimodal and is_multimodal_model(
+            self.hf_config.architectures
+        )
+        self.is_multimodal_gen = enable_multimodal and is_multimodal_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_image_gen = enable_multimodal and is_image_gen_model(
+            self.hf_config.architectures
+        )
+        self.is_audio_model = enable_multimodal and is_audio_model(
+            self.hf_config.architectures
+        )
         self.is_encoder_decoder = is_encoder_decoder_model(self.hf_config.architectures)
         self.dtype = _get_and_verify_dtype(self.hf_text_config, dtype)

@@ -234,6 +251,20 @@ class ModelConfig:
         if quant_cfg is None:
             # compressed-tensors uses a "compression_config" key
             quant_cfg = getattr(self.hf_config, "compression_config", None)
+        if quant_cfg is None:
+            # check if is modelopt model -- modelopt doesn't have corresponding field
+            # in hf `config.json` but has a standalone `hf_quant_config.json` in the root directory
+            # example: https://huggingface.co/nvidia/Llama-3.1-8B-Instruct-FP8/tree/main
+            is_local = os.path.exists(self.model_path)
+            modelopt_quant_config = {"quant_method": "modelopt"}
+            if not is_local:
+                from huggingface_hub import HfApi
+
+                hf_api = HfApi()
+                if hf_api.file_exists(self.model_path, "hf_quant_config.json"):
+                    quant_cfg = modelopt_quant_config
+            elif os.path.exists(os.path.join(self.model_path, "hf_quant_config.json")):
+                quant_cfg = modelopt_quant_config
         return quant_cfg

     # adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/config.py
@@ -264,6 +295,7 @@ class ModelConfig:
             "moe_wna16",
         ]
         compatible_quantization_methods = {
+            "modelopt_fp4": ["modelopt"],
             "w8a8_int8": ["compressed-tensors", "compressed_tensors"],
             "w8a8_fp8": ["compressed-tensors", "compressed_tensors"],
         }
@@ -470,8 +502,8 @@ multimodal_model_archs = [
     "Gemma3ForConditionalGeneration",
     "Grok1VForCausalLM",
     "Grok1AForCausalLM",
-    # TODO: add multimodal support for "Llama4ForConditionalGeneration",
     "LlavaLlamaForCausalLM",
+    "Llama4ForConditionalGeneration",
     "LlavaMistralForCausalLM",
     "LlavaQwenForCausalLM",
     "LlavaVidForCausalLM",
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -28,6 +28,18 @@ logger = logging.getLogger(__name__)


 class BaseGrammarObject(ABC):
+
+    def __init__(self):
+        self._finished = False
+
+    @property
+    def finished(self):
+        return self._finished
+
+    @finished.setter
+    def finished(self, finished):
+        self._finished = finished
+
     @abstractmethod
     def try_jump_forward(self, tokenizer) -> Optional[Tuple[List[int], str]]:
         """
@@ -59,6 +71,13 @@ class BaseGrammarObject(ABC):
         """
         raise NotImplementedError

+    @abstractmethod
+    def accept_token(self, token: int) -> None:
+        """
+        Accept a token in the grammar.
+        """
+        raise NotImplementedError
+
     @abstractmethod
     def allocate_vocab_mask(
         self, vocab_size: int, batch_size: int, device
@@ -90,7 +109,7 @@ class CacheEntry:
     event: Event


-class BaseGrammarBackend(ABC):
+class BaseGrammarBackend:
     def __init__(self):
         self.executor = ThreadPoolExecutor()
         self.cache: Dict[Tuple[str, str], CacheEntry] = {}
@@ -107,19 +126,15 @@ class BaseGrammarBackend(ABC):
         """
         raise ValueError(f"Invalid key_type: {key_type}={key_string}")

-    @abstractmethod
     def dispatch_json(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("json", key_string)

-    @abstractmethod
     def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("regex", key_string)

-    @abstractmethod
     def dispatch_ebnf(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("ebnf", key_string)

-    @abstractmethod
     def dispatch_structural_tag(self, key_string: str) -> Optional[BaseGrammarObject]:
         return self._not_supported("structural_tag", key_string)

@@ -195,4 +210,10 @@ def create_grammar_backend(
     else:
         raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")

+    if server_args.reasoning_parser and hasattr(tokenizer, "think_end_id"):
+        from .reasoner_grammar_backend import ReasonerGrammarBackend
+
+        grammar_backend = ReasonerGrammarBackend(
+            grammar_backend, tokenizer.think_end_id
+        )
     return grammar_backend
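Because BaseGrammarBackend is no longer an ABC and the dispatch_* hooks are plain overridable methods, a concrete backend only has to implement the key types it supports. A minimal sketch (the RegexOnlyBackend class below is hypothetical, not part of the diff):

    from typing import Optional

    from sglang.srt.constrained.base_grammar_backend import (
        BaseGrammarBackend,
        BaseGrammarObject,
    )


    class RegexOnlyBackend(BaseGrammarBackend):
        """Hypothetical backend that only supports regex-constrained decoding."""

        def dispatch_regex(self, key_string: str) -> Optional[BaseGrammarObject]:
            # Compile `key_string` into a grammar object here.
            return None

        # dispatch_json / dispatch_ebnf / dispatch_structural_tag no longer
        # need to be overridden; unhandled key types fall through to the base
        # class's _not_supported handling instead of failing at class
        # definition time.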
sglang/srt/constrained/llguidance_backend.py CHANGED
@@ -33,6 +33,7 @@ class GuidanceGrammar(BaseGrammarObject):
     def __init__(
         self, llguidance_tokenizer: llguidance.LLTokenizer, serialized_grammar: str
     ):
+        super().__init__()
         self.llguidance_tokenizer = llguidance_tokenizer
         self.serialized_grammar = serialized_grammar

sglang/srt/constrained/outlines_backend.py CHANGED
@@ -44,6 +44,7 @@ class OutlinesGrammar(BaseGrammarObject):
         guide: RegexGuide,
         jump_forward_map: Union[OutlinesJumpForwardMap, None],
     ) -> None:
+        super().__init__()
         self.guide = guide
         self.jump_forward_map = jump_forward_map
         self.state = 0
sglang/srt/constrained/reasoner_grammar_backend.py ADDED
@@ -0,0 +1,101 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""The baseclass of a backend for reasoner grammar-guided constrained decoding."""
+
+from concurrent.futures import Future
+from typing import List, Optional, Tuple
+
+import torch
+
+from .base_grammar_backend import BaseGrammarBackend, BaseGrammarObject
+
+
+class ReasonerGrammarObject(BaseGrammarObject):
+    def __init__(self, grammar: BaseGrammarObject, think_end_id):
+        super().__init__()
+        self.grammar = grammar
+        self.think_end_id = think_end_id
+        self.is_in_reasoning = True
+
+    @property
+    def finished(self):
+        return self.grammar.finished
+
+    @finished.setter
+    def finished(self, finished):
+        self.grammar.finished = finished
+
+    def allocate_vocab_mask(
+        self, vocab_size: int, batch_size: int, device
+    ) -> torch.Tensor:
+        return self.grammar.allocate_vocab_mask(vocab_size, batch_size, device)
+
+    def fill_vocab_mask(self, vocab_mask: torch.Tensor, idx: int) -> None:
+        if not self.is_in_reasoning:
+            self.grammar.fill_vocab_mask(vocab_mask, idx)
+
+    def move_vocab_mask(self, vocab_mask: torch.Tensor, device) -> torch.Tensor:
+        return self.grammar.move_vocab_mask(vocab_mask, device)
+
+    @property
+    def apply_vocab_mask(self):
+        return self.grammar.apply_vocab_mask
+
+    def accept_token(self, token: int):
+        if token == self.think_end_id:
+            self.is_in_reasoning = False
+
+        if not self.is_in_reasoning and token != self.think_end_id:
+            self.grammar.accept_token(token)
+
+    def try_jump_forward(self, tokenizer):
+        return self.grammar.try_jump_forward(tokenizer)
+
+    def jump_forward_str_state(self, helper):
+        return self.grammar.jump_forward_str_state(helper)
+
+    def jump_and_retokenize(
+        self, old_output_ids: List[int], new_output_ids: List[int], next_state: int
+    ):
+        return self.grammar.jump_and_retokenize(
+            old_output_ids, new_output_ids, next_state
+        )
+
+    def copy(self) -> BaseGrammarObject:
+        return ReasonerGrammarObject(self.grammar.copy(), self.think_end_id)
+
+
+class ReasonerGrammarBackend(BaseGrammarBackend):
+    def __init__(self, grammar_backend: BaseGrammarBackend, think_end_id):
+        self.grammar_backend = grammar_backend
+        self.think_end_id = think_end_id
+
+    def get_cached_value(self, key: Tuple[str, str]) -> Optional[ReasonerGrammarObject]:
+        grammar = self.grammar_backend.get_cached_value(key)
+        return ReasonerGrammarObject(grammar, self.think_end_id) if grammar else None
+
+    def get_future_value(self, key: Tuple[str, str]) -> Future:
+        grammar = Future()
+
+        def callback(f: Future):
+            if result := f.result():
+                grammar.set_result(ReasonerGrammarObject(result, self.think_end_id))
+            else:
+                grammar.set_result(None)
+
+        self.grammar_backend.get_future_value(key).add_done_callback(callback)
+        return grammar
+
+    def reset(self):
+        self.grammar_backend.reset()
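To illustrate the gating behavior, a small sketch of driving accept_token by hand; the RecordingGrammar stub and THINK_END_ID value are hypothetical stand-ins, not part of sglang:

    from sglang.srt.constrained.reasoner_grammar_backend import ReasonerGrammarObject


    class RecordingGrammar:
        """Hypothetical stand-in for a real BaseGrammarObject; records accepted tokens."""

        def __init__(self):
            self.finished = False
            self.accepted = []

        def accept_token(self, token: int) -> None:
            self.accepted.append(token)


    THINK_END_ID = 100  # assumed id of the tokenizer's think-end token

    inner = RecordingGrammar()
    wrapped = ReasonerGrammarObject(inner, THINK_END_ID)
    for token in [5, 6, THINK_END_ID, 7, 8]:
        wrapped.accept_token(token)

    # Tokens emitted inside the reasoning section (5, 6) and the think-end token
    # itself are swallowed; only the tokens after it reach the wrapped grammar.
    assert inner.accepted == [7, 8]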
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -48,6 +48,7 @@ class XGrammarGrammar(BaseGrammarObject):
         ctx: CompiledGrammar,
         override_stop_tokens: Optional[Union[List[int], int]],
     ) -> None:
+        super().__init__()
         self.matcher = matcher
         self.vocab_size = vocab_size
         self.ctx = ctx
sglang/srt/disaggregation/base/__init__.py ADDED
@@ -0,0 +1,8 @@
+from .conn import (
+    BaseKVBootstrapServer,
+    BaseKVManager,
+    BaseKVReceiver,
+    BaseKVSender,
+    KVArgs,
+    KVPoll,
+)
sglang/srt/disaggregation/base/conn.py ADDED
@@ -0,0 +1,113 @@
+from abc import ABC, abstractmethod
+from typing import Optional
+
+import numpy as np
+import numpy.typing as npt
+
+from sglang.srt.disaggregation.utils import DisaggregationMode
+from sglang.srt.server_args import ServerArgs
+
+
+class KVArgs:
+    engine_rank: int
+    kv_data_ptrs: list[int]
+    kv_data_lens: list[int]
+    kv_item_lens: list[int]
+    aux_data_ptrs: list[int]
+    aux_data_lens: list[int]
+    aux_item_lens: list[int]
+    ib_device: str
+    gpu_id: int
+
+
+class KVPoll:
+    Failed = 0
+    Bootstrapping = 1
+    WaitingForInput = 2
+    Transferring = 3
+    Success = 4
+
+
+class BaseKVManager(ABC):
+    """Base class for managing transfers states"""
+
+    @abstractmethod
+    def __init__(
+        self,
+        args: KVArgs,
+        disaggregation_mode: DisaggregationMode,
+        server_args: ServerArgs,
+    ): ...
+
+
+class BaseKVSender(ABC):
+
+    @abstractmethod
+    def __init__(
+        self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int
+    ): ...
+
+    @abstractmethod
+    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
+        """
+        Notify the decoder server about the kv indices length and aux index
+        """
+        ...
+
+    @abstractmethod
+    def send(self, kv_indices: npt.NDArray[np.int64]):
+        """
+        Send the kv cache at the given kv indices to the decoder server
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVReceiver(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        mgr: BaseKVManager,
+        bootstrap_addr: str,
+        bootstrap_room: Optional[int] = None,
+    ): ...
+
+    @abstractmethod
+    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
+        """
+        Notify the prefill server about the kv indices and aux index
+        """
+        ...
+
+    @abstractmethod
+    def poll(self) -> KVPoll:
+        """
+        Check the status of the kv cache transfer
+        """
+        ...
+
+    @abstractmethod
+    def failure_exception(self):
+        """
+        Raise an exception if the kv cache transfer fails
+        """
+        ...
+
+
+class BaseKVBootstrapServer(ABC):
+    @abstractmethod
+    def __init__(self, port: int): ...
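As a point of reference, a minimal sketch of a backend implementing the sender side of this interface; the NoopKVSender class is hypothetical, while the real implementation shipped in this release lives in sglang/srt/disaggregation/mooncake/conn.py:

    from typing import Optional

    import numpy as np
    import numpy.typing as npt

    from sglang.srt.disaggregation.base import BaseKVManager, BaseKVSender, KVPoll


    class NoopKVSender(BaseKVSender):
        """Hypothetical sender that reports success without transferring anything."""

        def __init__(self, mgr: BaseKVManager, bootstrap_addr: str, bootstrap_room: int):
            self.num_kv_indices = None

        def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
            # Tell the decoder how many kv indices (and which aux slot) to expect.
            self.num_kv_indices = num_kv_indices

        def send(self, kv_indices: npt.NDArray[np.int64]):
            # A real backend would push the kv cache pages referenced by
            # kv_indices to the decode server here.
            pass

        def poll(self) -> KVPoll:
            return KVPoll.Success

        def failure_exception(self):
            raise RuntimeError("NoopKVSender never fails")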
sglang/srt/disaggregation/decode.py CHANGED
@@ -24,12 +24,17 @@ import logging
 from dataclasses import dataclass
 from typing import TYPE_CHECKING, List, Optional, Tuple

+import numpy as np
 import torch
 from torch.distributed import ProcessGroup

-from sglang.srt.disaggregation.conn import KVArgs, KVManager, KVPoll, KVReceiver
+from sglang.srt.disaggregation.base import BaseKVManager, BaseKVReceiver, KVArgs, KVPoll
 from sglang.srt.disaggregation.utils import (
+    DisaggregationMode,
+    KVClassType,
     ReqToMetadataIdxAllocator,
+    TransferBackend,
+    get_kv_class,
     poll_and_all_reduce,
 )
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -49,7 +54,7 @@ if TYPE_CHECKING:
 @dataclass
 class DecodeRequest:
     req: Req
-    kv_receiver: KVReceiver
+    kv_receiver: BaseKVReceiver
     waiting_for_input: bool = False
     metadata_buffer_index: int = -1

@@ -73,6 +78,7 @@ class DecodePreallocQueue:
         tp_rank: int,
         tp_size: int,
         bootstrap_port: int,
+        transfer_backend: TransferBackend,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
@@ -92,9 +98,10 @@

         # Queue for requests pending pre-allocation
         self.queue: List[DecodeRequest] = []
+        self.transfer_backend = transfer_backend
         self.kv_manager = self._init_kv_manager()

-    def _init_kv_manager(self) -> KVManager:
+    def _init_kv_manager(self) -> BaseKVManager:
         kv_args = KVArgs()
         kv_args.engine_rank = self.tp_rank
         kv_data_ptrs, kv_data_lens, kv_item_lens = (
@@ -115,13 +122,18 @@
             metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
         ]
         kv_args.ib_device = "mock-ib-device"
-        kv_manager = KVManager(kv_args)
+        kv_args.gpu_id = self.scheduler.gpu_id
+        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
+        kv_manager = kv_manager_class(
+            kv_args, DisaggregationMode.DECODE, self.scheduler.server_args
+        )
         return kv_manager

     def add(self, req: Req) -> None:
         """Add a request to the pending queue."""

-        kv_receiver = KVReceiver(
+        kv_receiver_class = get_kv_class(self.transfer_backend, KVClassType.RECEIVER)
+        kv_receiver = kv_receiver_class(
            mgr=self.kv_manager,
            bootstrap_addr=f"{req.bootstrap_host}:{self.bootstrap_port}",
            bootstrap_room=req.bootstrap_room,
@@ -186,6 +198,7 @@
            ]
            .cpu()
            .numpy()
+           .astype(np.int64)
        )

        decode_req.metadata_buffer_index = (