sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (176)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +6 -0
  3. sglang/bench_one_batch.py +1 -1
  4. sglang/bench_one_batch_server.py +1 -1
  5. sglang/bench_serving.py +3 -1
  6. sglang/check_env.py +3 -4
  7. sglang/lang/backend/openai.py +18 -5
  8. sglang/lang/chat_template.py +28 -7
  9. sglang/lang/interpreter.py +7 -3
  10. sglang/lang/ir.py +10 -0
  11. sglang/srt/_custom_ops.py +1 -1
  12. sglang/srt/code_completion_parser.py +174 -0
  13. sglang/srt/configs/__init__.py +2 -6
  14. sglang/srt/configs/deepseekvl2.py +667 -0
  15. sglang/srt/configs/janus_pro.py +3 -4
  16. sglang/srt/configs/load_config.py +1 -0
  17. sglang/srt/configs/model_config.py +63 -11
  18. sglang/srt/configs/utils.py +25 -0
  19. sglang/srt/connector/__init__.py +51 -0
  20. sglang/srt/connector/base_connector.py +112 -0
  21. sglang/srt/connector/redis.py +85 -0
  22. sglang/srt/connector/s3.py +122 -0
  23. sglang/srt/connector/serde/__init__.py +31 -0
  24. sglang/srt/connector/serde/safe_serde.py +29 -0
  25. sglang/srt/connector/serde/serde.py +43 -0
  26. sglang/srt/connector/utils.py +35 -0
  27. sglang/srt/conversation.py +88 -0
  28. sglang/srt/disaggregation/conn.py +81 -0
  29. sglang/srt/disaggregation/decode.py +495 -0
  30. sglang/srt/disaggregation/mini_lb.py +285 -0
  31. sglang/srt/disaggregation/prefill.py +249 -0
  32. sglang/srt/disaggregation/utils.py +44 -0
  33. sglang/srt/distributed/parallel_state.py +10 -3
  34. sglang/srt/entrypoints/engine.py +55 -5
  35. sglang/srt/entrypoints/http_server.py +71 -12
  36. sglang/srt/function_call_parser.py +164 -54
  37. sglang/srt/hf_transformers_utils.py +28 -3
  38. sglang/srt/layers/activation.py +4 -2
  39. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +295 -0
  41. sglang/srt/layers/attention/flashinfer_backend.py +1 -1
  42. sglang/srt/layers/attention/flashmla_backend.py +284 -0
  43. sglang/srt/layers/attention/triton_backend.py +171 -38
  44. sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
  45. sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
  46. sglang/srt/layers/attention/utils.py +53 -0
  47. sglang/srt/layers/attention/vision.py +9 -28
  48. sglang/srt/layers/dp_attention.py +62 -23
  49. sglang/srt/layers/elementwise.py +411 -0
  50. sglang/srt/layers/layernorm.py +24 -2
  51. sglang/srt/layers/linear.py +17 -5
  52. sglang/srt/layers/logits_processor.py +26 -7
  53. sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
  54. sglang/srt/layers/moe/ep_moe/layer.py +273 -1
  55. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
  56. sglang/srt/layers/moe/fused_moe_native.py +2 -1
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
  63. sglang/srt/layers/moe/router.py +342 -0
  64. sglang/srt/layers/moe/topk.py +31 -18
  65. sglang/srt/layers/parameter.py +1 -1
  66. sglang/srt/layers/quantization/__init__.py +184 -126
  67. sglang/srt/layers/quantization/base_config.py +5 -0
  68. sglang/srt/layers/quantization/blockwise_int8.py +1 -1
  69. sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
  70. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
  71. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
  72. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
  73. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
  74. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
  75. sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
  76. sglang/srt/layers/quantization/fp8.py +76 -34
  77. sglang/srt/layers/quantization/fp8_kernel.py +24 -8
  78. sglang/srt/layers/quantization/fp8_utils.py +284 -28
  79. sglang/srt/layers/quantization/gptq.py +36 -9
  80. sglang/srt/layers/quantization/kv_cache.py +98 -0
  81. sglang/srt/layers/quantization/modelopt_quant.py +9 -7
  82. sglang/srt/layers/quantization/utils.py +153 -0
  83. sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
  84. sglang/srt/layers/rotary_embedding.py +66 -87
  85. sglang/srt/layers/sampler.py +1 -1
  86. sglang/srt/lora/layers.py +68 -0
  87. sglang/srt/lora/lora.py +2 -22
  88. sglang/srt/lora/lora_manager.py +47 -23
  89. sglang/srt/lora/mem_pool.py +110 -51
  90. sglang/srt/lora/utils.py +12 -1
  91. sglang/srt/managers/cache_controller.py +4 -5
  92. sglang/srt/managers/data_parallel_controller.py +31 -9
  93. sglang/srt/managers/expert_distribution.py +81 -0
  94. sglang/srt/managers/io_struct.py +39 -3
  95. sglang/srt/managers/mm_utils.py +373 -0
  96. sglang/srt/managers/multimodal_processor.py +68 -0
  97. sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
  98. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
  99. sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
  100. sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
  101. sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
  102. sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
  103. sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
  104. sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
  105. sglang/srt/managers/schedule_batch.py +134 -31
  106. sglang/srt/managers/scheduler.py +325 -38
  107. sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
  108. sglang/srt/managers/session_controller.py +1 -1
  109. sglang/srt/managers/tokenizer_manager.py +59 -23
  110. sglang/srt/managers/tp_worker.py +1 -1
  111. sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
  112. sglang/srt/managers/utils.py +6 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +27 -8
  114. sglang/srt/mem_cache/memory_pool.py +258 -98
  115. sglang/srt/mem_cache/paged_allocator.py +2 -2
  116. sglang/srt/mem_cache/radix_cache.py +4 -4
  117. sglang/srt/model_executor/cuda_graph_runner.py +85 -28
  118. sglang/srt/model_executor/forward_batch_info.py +81 -15
  119. sglang/srt/model_executor/model_runner.py +70 -6
  120. sglang/srt/model_loader/loader.py +160 -2
  121. sglang/srt/model_loader/weight_utils.py +45 -0
  122. sglang/srt/models/deepseek_janus_pro.py +29 -86
  123. sglang/srt/models/deepseek_nextn.py +22 -10
  124. sglang/srt/models/deepseek_v2.py +326 -192
  125. sglang/srt/models/deepseek_vl2.py +358 -0
  126. sglang/srt/models/gemma3_causal.py +684 -0
  127. sglang/srt/models/gemma3_mm.py +462 -0
  128. sglang/srt/models/grok.py +374 -119
  129. sglang/srt/models/llama.py +47 -7
  130. sglang/srt/models/llama_eagle.py +1 -0
  131. sglang/srt/models/llama_eagle3.py +196 -0
  132. sglang/srt/models/llava.py +3 -3
  133. sglang/srt/models/llavavid.py +3 -3
  134. sglang/srt/models/minicpmo.py +1995 -0
  135. sglang/srt/models/minicpmv.py +62 -137
  136. sglang/srt/models/mllama.py +4 -4
  137. sglang/srt/models/phi3_small.py +1 -1
  138. sglang/srt/models/qwen2.py +3 -0
  139. sglang/srt/models/qwen2_5_vl.py +68 -146
  140. sglang/srt/models/qwen2_classification.py +75 -0
  141. sglang/srt/models/qwen2_moe.py +9 -1
  142. sglang/srt/models/qwen2_vl.py +25 -63
  143. sglang/srt/openai_api/adapter.py +145 -47
  144. sglang/srt/openai_api/protocol.py +23 -2
  145. sglang/srt/sampling/sampling_batch_info.py +1 -1
  146. sglang/srt/sampling/sampling_params.py +6 -6
  147. sglang/srt/server_args.py +104 -14
  148. sglang/srt/speculative/build_eagle_tree.py +7 -347
  149. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
  150. sglang/srt/speculative/eagle_utils.py +208 -252
  151. sglang/srt/speculative/eagle_worker.py +139 -53
  152. sglang/srt/speculative/spec_info.py +6 -1
  153. sglang/srt/torch_memory_saver_adapter.py +22 -0
  154. sglang/srt/utils.py +182 -21
  155. sglang/test/__init__.py +0 -0
  156. sglang/test/attention/__init__.py +0 -0
  157. sglang/test/attention/test_flashattn_backend.py +312 -0
  158. sglang/test/runners.py +2 -0
  159. sglang/test/test_activation.py +2 -1
  160. sglang/test/test_block_fp8.py +5 -4
  161. sglang/test/test_block_fp8_ep.py +2 -1
  162. sglang/test/test_dynamic_grad_mode.py +58 -0
  163. sglang/test/test_layernorm.py +3 -2
  164. sglang/test/test_utils.py +55 -4
  165. sglang/utils.py +31 -0
  166. sglang/version.py +1 -1
  167. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
  168. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
  169. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
  170. sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
  171. sglang/srt/managers/image_processor.py +0 -55
  172. sglang/srt/managers/image_processors/base_image_processor.py +0 -219
  173. sglang/srt/managers/image_processors/minicpmv.py +0 -86
  174. sglang/srt/managers/multi_modality_padding.py +0 -134
  175. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
  176. {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -16,6 +16,7 @@
 import argparse
 import dataclasses
 import logging
+import os
 import random
 import tempfile
 from typing import List, Optional
@@ -24,12 +25,14 @@ from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_device,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
+    is_remote_url,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -51,9 +54,10 @@ class ServerArgs:
     quantization: Optional[str] = None
     quantization_param_path: nullable_str = None
     context_length: Optional[int] = None
-    device: str = "cuda"
+    device: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    completion_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
 
@@ -122,7 +126,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] = "outlines"
+    grammar_backend: Optional[str] = "xgrammar"
 
     # Speculative decoding
     speculative_algorithm: Optional[str] = None
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
+    enable_deepep_moe: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -170,7 +175,9 @@ class ServerArgs:
     enable_custom_logit_processor: bool = False
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
+    hicache_ratio: float = 2.0
     enable_flashinfer_mla: bool = False
+    enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
 
@@ -179,11 +186,18 @@ class ServerArgs:
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
 
+    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
+    disaggregation_mode: str = "null"
+    disaggregation_bootstrap_port: int = 8998
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
 
+        if self.device is None:
+            self.device = get_device()
+
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
@@ -222,6 +236,11 @@ class ServerArgs:
 
         assert self.chunked_prefill_size % self.page_size == 0
 
+        if self.enable_flashmla is True:
+            logger.warning(
+                "FlashMLA only supports a page_size of 64, change page_size to 64."
+            )
+            self.page_size = 64
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -262,25 +281,33 @@ class ServerArgs:
 
         # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
             )
+        # DeepEP MoE
+        if self.enable_deepep_moe:
+            self.ep_size = self.dp_size
+            logger.info(
+                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
+            )
 
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"
 
-        if self.speculative_algorithm == "EAGLE":
+        if (
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "EAGLE3"
+        ):
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
@@ -296,10 +323,29 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        if is_remote_url(self.model_path):
+            self.load_format = "remote"
+
         # AMD-specific Triton attention KV splits default number
         if is_hip():
             self.triton_attention_num_kv_splits = 16
 
+        # PD disaggregation
+        if self.disaggregation_mode == "prefill":
+            self.disable_cuda_graph = True
+            logger.warning("KV cache is forced as chunk cache for decode server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for prefill server")
+        elif self.disaggregation_mode == "decode":
+            self.disable_radix_cache = True
+            logger.warning("Cuda graph is disabled for prefill server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for decode server")
+
+        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
+            "1" if self.enable_torch_compile else "0"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
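
Note: the sketch below shows how these __post_init__ changes combine when ServerArgs is constructed directly from Python. It is a rough illustration only; it assumes an sglang 0.4.4.post2 install on a supported accelerator, that model_path is the only required field, and the model path shown is a placeholder.

    from sglang.srt.server_args import ServerArgs

    args = ServerArgs(
        model_path="my-org/my-model",    # placeholder, not a real checkpoint
        disaggregation_mode="decode",    # new: "null" (default), "prefill", or "decode"
        enable_flashmla=True,            # new: forces page_size to 64 with a warning
    )

    # device was left unset, so __post_init__ filled it in via get_device();
    # decode mode also disabled the radix cache and the overlap scheduler.
    print(args.device, args.page_size, args.disable_radix_cache, args.disable_overlap_schedule)
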
@@ -345,9 +391,11 @@ class ServerArgs:
                 "safetensors",
                 "npcache",
                 "dummy",
+                "sharded_state",
                 "gguf",
                 "bitsandbytes",
                 "layered",
+                "remote",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -429,9 +477,8 @@ class ServerArgs:
         parser.add_argument(
             "--device",
             type=str,
-            default="cuda",
-            choices=["cuda", "xpu", "hpu", "cpu"],
-            help="The device type.",
+            default=ServerArgs.device,
+            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
         )
         parser.add_argument(
             "--served-model-name",
@@ -445,6 +492,12 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--completion-template",
+            type=str,
+            default=ServerArgs.completion_template,
+            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
+        )
         parser.add_argument(
             "--is-embedding",
             action="store_true",
@@ -722,7 +775,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton", "torch_native"],
+            choices=["flashinfer", "triton", "torch_native", "fa3"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -745,6 +798,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--enable-flashmla",
+            action="store_true",
+            help="Enable FlashMLA decode optimization",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -755,7 +813,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "NEXTN"],
+            choices=["EAGLE", "EAGLE3", "NEXTN"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -984,6 +1042,18 @@ class ServerArgs:
             action="store_true",
             help="Enable hierarchical cache",
         )
+        parser.add_argument(
+            "--hicache-ratio",
+            type=float,
+            required=False,
+            default=ServerArgs.hicache_ratio,
+            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
+        )
+        parser.add_argument(
+            "--enable-deepep-moe",
+            action="store_true",
+            help="Enabling DeepEP MoE implementation for EP MoE.",
+        )
 
         # Server warmups
         parser.add_argument(
@@ -1014,6 +1084,21 @@ class ServerArgs:
             help="Inject the outputs from jax as the input of every layer.",
         )
 
+        # Disaggregation
+        parser.add_argument(
+            "--disaggregation-mode",
+            type=str,
+            default="null",
+            choices=["null", "prefill", "decode"],
+            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
+        )
+        parser.add_argument(
+            "--disaggregation-bootstrap-port",
+            type=int,
+            default=ServerArgs.disaggregation_bootstrap_port,
+            help="Bootstrap server port on the prefill server. Default is 8998.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
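
Note: a minimal sketch of exercising the new flags through the same add_cli_args / from_cli_args path shown in this diff. It assumes the full parser (most of which is not reproduced here) defines the remaining options such as --model-path with their usual defaults; the model name is a placeholder.

    import argparse

    from sglang.srt.server_args import ServerArgs

    parser = argparse.ArgumentParser()
    ServerArgs.add_cli_args(parser)

    raw_args = parser.parse_args(
        [
            "--model-path", "my-org/my-model",           # placeholder model
            "--disaggregation-mode", "prefill",          # new PD disaggregation flag
            "--disaggregation-bootstrap-port", "8998",   # new bootstrap port flag
            "--attention-backend", "fa3",                # new backend choice
            "--speculative-algorithm", "EAGLE3",         # new speculative choice
        ]
    )
    server_args = ServerArgs.from_cli_args(raw_args)
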
@@ -1088,6 +1173,9 @@ class PortArgs:
     # The port for nccl initialization (torch.dist)
     nccl_port: int
 
+    # The ipc filename for rpc call between Engine and Scheduler
+    rpc_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
@@ -1106,6 +1194,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
+                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
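
Note: the new rpc_ipc_name endpoint follows the same naming scheme as the existing IPC endpoints built in PortArgs.init_new, in both the single-node case above and the DP-attention/multi-node case below. The snippet reproduces that scheme in isolation as a standalone illustration (the host and base port are made-up values), not as sglang API.

    import tempfile

    # Single-node case: a throwaway temp-file path becomes the ipc:// endpoint.
    rpc_ipc_name = f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}"

    # DP-attention / multi-node case: a TCP endpoint derived from a base port.
    dist_init_host, port_base = "127.0.0.1", 20000
    rpc_ipc_name_tcp = f"tcp://{dist_init_host}:{port_base + 2}"
    print(rpc_ipc_name, rpc_ipc_name_tcp)
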
@@ -1131,6 +1220,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
+                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
             )
 
 
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-from sgl_kernel import build_tree_kernel as sgl_build_tree_kernel
-from sgl_kernel import build_tree_kernel_efficient as sgl_build_tree_kernel_efficient
+
+from sglang.srt.utils import is_cuda_available, is_hip
+
+if is_cuda_available() or is_hip():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(
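
Note: the next hunk touches build_tree_kernel_efficient_preprocess. As a self-contained illustration of what that preprocessing does (top-k over cumulative scores, indices re-sorted into positional order, draft tokens gathered and the verified token prepended), here is a CPU-only toy with made-up scores and token ids:

    import torch

    # One sequence, six candidate draft tokens with cumulative scores.
    score_list = torch.tensor([[0.9, 0.1, 0.7, 0.3, 0.8, 0.2]])
    ss_token_list = torch.tensor([[101, 102, 103, 104, 105, 106]])
    verified_id = torch.tensor([42])
    num_verify_tokens = 4

    # Keep the (num_verify_tokens - 1) best-scoring candidates, in positional order.
    top_scores_index = torch.topk(score_list, num_verify_tokens - 1, dim=-1).indices
    top_scores_index = torch.sort(top_scores_index).values        # tensor([[0, 2, 4]])

    # Gather their token ids and prepend the last verified token.
    draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
    draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
    print(draft_tokens)                                            # tensor([ 42, 101, 103, 105])
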
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
-
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
 
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [
-                0,
-                1,
-                2,
-                3,
-                10,
-                11,
-                12,
-                13,
-                20,
-                21,
-                22,
-                30,
-                110,
-                130,
-                150,
-                160,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                217,
-                218,
-                219,
-                220,
-                230,
-                310,
-                311,
-                312,
-                313,
-                314,
-                315,
-                316,
-                317,
-                320,
-                321,
-                322,
-                330,
-                360,
-                380,
-                390,
-                410,
-                411,
-                412,
-                413,
-                414,
-                415,
-                416,
-                417,
-                418,
-                419,
-                420,
-                421,
-                422,
-                423,
-                430,
-                431,
-                440,
-                441,
-                460,
-                470,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [
-                -1,
-                0,
-                1,
-                2,
-                3,
-                4,
-                5,
-                6,
-                7,
-                8,
-                9,
-                10,
-                11,
-                12,
-                20,
-                30,
-                21,
-                13,
-                22,
-                40,
-                23,
-                110,
-                130,
-                160,
-                150,
-                190,
-                120,
-                111,
-                121,
-                200,
-                180,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                220,
-                230,
-                217,
-                310,
-                311,
-                312,
-                313,
-                320,
-                314,
-                321,
-                315,
-                316,
-                317,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974,
-        29896,
-        29906,
-        29889,
-        29974,
-        29946,
-        29896,
-        29946,
-        13,
-        13,
-        22550,
-        4136,
-        16492,
-        8439,
-        29871,
-        29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()