sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (158)
  1. sglang/bench_one_batch_server.py +17 -2
  2. sglang/bench_serving.py +170 -24
  3. sglang/srt/configs/internvl.py +4 -2
  4. sglang/srt/configs/janus_pro.py +1 -1
  5. sglang/srt/configs/model_config.py +60 -1
  6. sglang/srt/configs/update_config.py +119 -0
  7. sglang/srt/conversation.py +69 -1
  8. sglang/srt/disaggregation/decode.py +21 -5
  9. sglang/srt/disaggregation/mooncake/conn.py +35 -4
  10. sglang/srt/disaggregation/nixl/conn.py +6 -6
  11. sglang/srt/disaggregation/prefill.py +2 -2
  12. sglang/srt/disaggregation/utils.py +1 -1
  13. sglang/srt/distributed/parallel_state.py +44 -17
  14. sglang/srt/entrypoints/EngineBase.py +8 -0
  15. sglang/srt/entrypoints/engine.py +40 -6
  16. sglang/srt/entrypoints/http_server.py +111 -24
  17. sglang/srt/entrypoints/http_server_engine.py +1 -1
  18. sglang/srt/entrypoints/openai/protocol.py +4 -2
  19. sglang/srt/eplb/__init__.py +0 -0
  20. sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
  21. sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
  22. sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
  23. sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
  24. sglang/srt/{managers → eplb}/expert_location.py +1 -1
  25. sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
  26. sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
  27. sglang/srt/hf_transformers_utils.py +2 -1
  28. sglang/srt/layers/activation.py +2 -2
  29. sglang/srt/layers/amx_utils.py +86 -0
  30. sglang/srt/layers/attention/ascend_backend.py +219 -0
  31. sglang/srt/layers/attention/flashattention_backend.py +32 -9
  32. sglang/srt/layers/attention/tbo_backend.py +37 -9
  33. sglang/srt/layers/communicator.py +20 -2
  34. sglang/srt/layers/dp_attention.py +9 -3
  35. sglang/srt/layers/elementwise.py +76 -12
  36. sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
  37. sglang/srt/layers/layernorm.py +26 -0
  38. sglang/srt/layers/linear.py +84 -14
  39. sglang/srt/layers/logits_processor.py +4 -4
  40. sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
  41. sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
  42. sglang/srt/layers/moe/ep_moe/layer.py +176 -15
  43. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
  44. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
  45. sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
  46. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
  47. sglang/srt/layers/moe/router.py +60 -22
  48. sglang/srt/layers/moe/topk.py +10 -28
  49. sglang/srt/layers/parameter.py +67 -7
  50. sglang/srt/layers/quantization/__init__.py +2 -0
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
  52. sglang/srt/layers/quantization/fp8.py +72 -7
  53. sglang/srt/layers/quantization/fp8_kernel.py +1 -1
  54. sglang/srt/layers/quantization/fp8_utils.py +1 -2
  55. sglang/srt/layers/quantization/gptq.py +5 -1
  56. sglang/srt/layers/quantization/modelopt_quant.py +244 -1
  57. sglang/srt/layers/quantization/moe_wna16.py +1 -1
  58. sglang/srt/layers/quantization/quant_utils.py +166 -0
  59. sglang/srt/layers/quantization/w4afp8.py +264 -0
  60. sglang/srt/layers/quantization/w8a8_int8.py +52 -1
  61. sglang/srt/layers/rotary_embedding.py +2 -2
  62. sglang/srt/layers/vocab_parallel_embedding.py +20 -10
  63. sglang/srt/lora/lora.py +4 -5
  64. sglang/srt/lora/lora_manager.py +73 -20
  65. sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
  66. sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
  67. sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
  68. sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
  69. sglang/srt/managers/cache_controller.py +41 -195
  70. sglang/srt/managers/configure_logging.py +1 -1
  71. sglang/srt/managers/io_struct.py +58 -14
  72. sglang/srt/managers/mm_utils.py +77 -61
  73. sglang/srt/managers/multimodal_processor.py +2 -6
  74. sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
  75. sglang/srt/managers/schedule_batch.py +78 -85
  76. sglang/srt/managers/scheduler.py +130 -64
  77. sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
  78. sglang/srt/managers/session_controller.py +12 -3
  79. sglang/srt/managers/tokenizer_manager.py +314 -103
  80. sglang/srt/managers/tp_worker.py +13 -1
  81. sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
  82. sglang/srt/mem_cache/allocator.py +290 -0
  83. sglang/srt/mem_cache/chunk_cache.py +34 -2
  84. sglang/srt/mem_cache/hiradix_cache.py +2 -0
  85. sglang/srt/mem_cache/memory_pool.py +402 -66
  86. sglang/srt/mem_cache/memory_pool_host.py +6 -109
  87. sglang/srt/mem_cache/multimodal_cache.py +3 -0
  88. sglang/srt/mem_cache/radix_cache.py +8 -4
  89. sglang/srt/model_executor/cuda_graph_runner.py +2 -1
  90. sglang/srt/model_executor/forward_batch_info.py +17 -4
  91. sglang/srt/model_executor/model_runner.py +297 -56
  92. sglang/srt/model_loader/loader.py +41 -0
  93. sglang/srt/model_loader/weight_utils.py +72 -4
  94. sglang/srt/models/deepseek_nextn.py +1 -3
  95. sglang/srt/models/deepseek_v2.py +195 -45
  96. sglang/srt/models/deepseek_vl2.py +3 -5
  97. sglang/srt/models/gemma3_causal.py +1 -2
  98. sglang/srt/models/gemma3n_causal.py +4 -3
  99. sglang/srt/models/gemma3n_mm.py +4 -20
  100. sglang/srt/models/hunyuan.py +1 -1
  101. sglang/srt/models/kimi_vl.py +1 -2
  102. sglang/srt/models/llama.py +10 -4
  103. sglang/srt/models/llama4.py +32 -45
  104. sglang/srt/models/llama_eagle3.py +61 -11
  105. sglang/srt/models/llava.py +5 -5
  106. sglang/srt/models/minicpmo.py +2 -2
  107. sglang/srt/models/mistral.py +1 -1
  108. sglang/srt/models/mllama4.py +402 -89
  109. sglang/srt/models/phi4mm.py +1 -3
  110. sglang/srt/models/pixtral.py +3 -7
  111. sglang/srt/models/qwen2.py +31 -3
  112. sglang/srt/models/qwen2_5_vl.py +1 -3
  113. sglang/srt/models/qwen2_audio.py +200 -0
  114. sglang/srt/models/qwen2_moe.py +32 -6
  115. sglang/srt/models/qwen2_vl.py +1 -4
  116. sglang/srt/models/qwen3.py +94 -25
  117. sglang/srt/models/qwen3_moe.py +68 -21
  118. sglang/srt/models/vila.py +3 -8
  119. sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
  120. sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
  121. sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
  122. sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
  123. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
  124. sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
  125. sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
  126. sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
  127. sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
  128. sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
  129. sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
  130. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
  131. sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
  132. sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
  133. sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
  134. sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
  135. sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
  136. sglang/srt/operations_strategy.py +6 -2
  137. sglang/srt/reasoning_parser.py +26 -0
  138. sglang/srt/sampling/sampling_batch_info.py +39 -1
  139. sglang/srt/server_args.py +84 -22
  140. sglang/srt/speculative/build_eagle_tree.py +57 -18
  141. sglang/srt/speculative/eagle_worker.py +6 -4
  142. sglang/srt/two_batch_overlap.py +203 -27
  143. sglang/srt/utils.py +343 -163
  144. sglang/srt/warmup.py +12 -3
  145. sglang/test/runners.py +10 -1
  146. sglang/test/test_cutlass_w4a8_moe.py +281 -0
  147. sglang/test/test_utils.py +15 -3
  148. sglang/utils.py +5 -5
  149. sglang/version.py +1 -1
  150. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
  151. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
  152. sglang/math_utils.py +0 -8
  153. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
  154. /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
  155. /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
  156. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
  157. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
  158. {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,6 +46,7 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
@@ -61,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -98,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -129,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -212,11 +217,13 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
+    hicache_io_backend: str = ""
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
+    enable_triton_kernel_moe: bool = False
     warmups: Optional[str] = None
 
     # Debug tensor dumps
@@ -315,6 +322,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
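
Illustrative note on the hunk above: the multimodal discount compounds with the base heuristic, so a model that would otherwise default to `mem_fraction_static = 0.88` ends up at about 0.792. A minimal sketch of that arithmetic (the helper name is hypothetical, not from the package):

    def effective_mem_fraction(base: float, is_multimodal: bool) -> float:
        # Mirrors the hunk above: multimodal models reserve ~10% extra
        # headroom for the image processor by shrinking the static fraction.
        return base * 0.90 if is_multimodal else base

    assert effective_mem_fraction(0.88, False) == 0.88
    assert round(effective_mem_fraction(0.88, True), 3) == 0.792
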
@@ -376,6 +391,12 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -399,10 +420,6 @@ class ServerArgs:
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -485,12 +502,6 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
-            if self.page_size > 1 and self.speculative_eagle_topk > 1:
-                self.speculative_eagle_topk = 1
-                logger.warning(
-                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-                )
-
             if (
                 self.speculative_eagle_topk == 1
                 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -587,6 +598,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
         )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -601,6 +618,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
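
A hedged usage sketch of the two new options above, assuming `sgl.Engine` forwards keyword arguments to `ServerArgs` as it does for existing server options (the model path is a placeholder):

    import sglang as sgl

    engine = sgl.Engine(
        model_path="meta-llama/Llama-3.1-8B-Instruct",  # placeholder
        nccl_port=25000,          # pin the NCCL setup port instead of a random one
        skip_server_warmup=True,  # skip the startup warmup request
    )
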
@@ -686,6 +708,7 @@ class ServerArgs:
                 "w8a8_fp8",
                 "moe_wna16",
                 "qoq",
+                "w4afp8",
             ],
             help="The quantization method.",
         )
@@ -817,6 +840,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
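
An illustrative reading of the `--hybrid-kvcache-ratio` endpoints above, assuming the ratio linearly interpolates the SWA-to-full buffer size between the two documented extremes (the helper is hypothetical, not from the package):

    def swa_to_full_size_ratio(ratio: float, local_attention_size: int, context_length: int) -> float:
        # ratio = 0.0 -> pure uniform: swa_size / full_size = 1
        # ratio = 1.0 -> pure hybrid:  swa_size / full_size = local / context
        pure_hybrid = local_attention_size / context_length
        return (1.0 - ratio) * 1.0 + ratio * pure_hybrid

    # e.g. a 4096-token local window over a 131072-token context,
    # at the const=0.5 value used when the flag is passed bare:
    print(swa_to_full_size_ratio(0.5, 4096, 131072))  # 0.515625
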
@@ -920,8 +955,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1092,6 +1133,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1186,6 +1228,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1485,6 +1532,13 @@ class ServerArgs:
             default=ServerArgs.hicache_write_policy,
             help="The write policy of hierarchical cache.",
         )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+            help="The IO backend for KV cache transfer between CPU and GPU",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -1510,6 +1564,11 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="Use triton moe grouped gemm kernel.",
+        )
         parser.add_argument(
             "--warmups",
             type=str,
@@ -1706,14 +1765,17 @@ class PortArgs:
 
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
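
For clarity, a standalone sketch of the port selection in `PortArgs.init_new` above: an explicit `--nccl-port` wins; otherwise the search starts at a random offset above the HTTP port and steps by +42 until it would pass 60000, then backs off by 43 (`is_port_available` is stubbed here; the function name is hypothetical):

    import random

    def pick_nccl_port(http_port: int, nccl_port=None, is_port_available=lambda p: True):
        # Mirrors the diff: an explicitly configured port is used as-is.
        if nccl_port is not None:
            return nccl_port
        port = http_port + random.randint(100, 1000)
        while not is_port_available(port):
            port = port + 42 if port < 60000 else port - 43
        return port
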
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip, rank0_print
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
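
A worked example of the `QLEN_ONLY_BITPACKING` dtype choice above: each draft token's mask row packs `num_verify_tokens` bits into `(num_verify_tokens + 7) // 8` bytes, and the log2 index selects the smallest unsigned dtype that holds them (illustrative sketch, not from the package):

    import math

    def packed_mask_dtype(num_verify_tokens: int) -> str:
        # Same index computation as the hunk above; returns the name of the
        # dtype each draft token's mask row would be packed into.
        names = ["uint8", "uint16", "uint32"]
        idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        return names[idx]

    assert packed_mask_dtype(8) == "uint8"    # 8 mask bits -> 1 byte
    assert packed_mask_dtype(16) == "uint16"  # 16 bits -> 2 bytes
    assert packed_mask_dtype(32) == "uint32"  # 32 bits -> 4 bytes
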
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    rank0_print("=========== build tree kernel efficient ==========")
-    # rank0_print(f"{tree_mask=}", flush=True)
-    rank0_print(f"{position=}", flush=True)
-    rank0_print(f"{retrive_index=}", flush=True)
-    rank0_print(f"{retrive_next_token=}", flush=True)
-    rank0_print(f"{retrive_next_sibling=}", flush=True)
-    rank0_print(f"{draft_tokens=}", flush=True)
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.get_hot_token_id().to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
             )
             batch.return_hidden_states = False
             model_worker_batch = batch.get_model_worker_batch()
-            model_worker_batch.spec_num_draft_tokens = self.speculative_num_draft_tokens
+            model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
             assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
             forward_batch = ForwardBatch.init_new(
                 model_worker_batch, self.draft_model_runner
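
A hedged reading of the final hunk: in this draft-extend path the batch carries one token per draft step plus the verified root token, so the count is derived from `speculative_num_steps` rather than the verify-time `speculative_num_draft_tokens` (which can differ when the draft tree is wider than a single chain). For example:

    speculative_num_steps = 3                          # hypothetical setting
    spec_num_draft_tokens = speculative_num_steps + 1  # 3 drafted tokens + 1 root = 4
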