sglang 0.4.8.post1__py3-none-any.whl → 0.4.9.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +170 -24
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +60 -1
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +69 -1
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/mooncake/conn.py +35 -4
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/http_server_engine.py +1 -1
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +20 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +215 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +81 -8
- sglang/srt/layers/moe/ep_moe/layer.py +176 -15
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +3 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +211 -74
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +176 -0
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +72 -7
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/modelopt_quant.py +244 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w4afp8.py +264 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +20 -10
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/qkv_lora_b.py +30 -19
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +27 -11
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +27 -15
- sglang/srt/managers/cache_controller.py +41 -195
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +58 -14
- sglang/srt/managers/mm_utils.py +77 -61
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +78 -85
- sglang/srt/managers/scheduler.py +130 -64
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/hiradix_cache.py +2 -0
- sglang/srt/mem_cache/memory_pool.py +402 -66
- sglang/srt/mem_cache/memory_pool_host.py +6 -109
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/mem_cache/radix_cache.py +8 -4
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +195 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +402 -89
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +2 -2
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +65 -66
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +84 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +203 -27
- sglang/srt/utils.py +343 -163
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_cutlass_w4a8_moe.py +281 -0
- sglang/test/test_utils.py +15 -3
- sglang/utils.py +5 -5
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/METADATA +12 -8
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/RECORD +157 -146
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.post1.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,6 +46,7 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
@@ -61,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -98,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -129,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
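The lora_paths change widens the accepted type from a plain list of adapter paths to either a list or an explicit name-to-path mapping. One plausible normalization of the two forms, as an illustrative sketch (the helper below is ours, not part of sglang):

    from typing import Dict, List, Optional, Union

    # Hypothetical helper: fold both lora_paths forms into one dict.
    # A bare list reuses each path as its own adapter name; a dict is
    # passed through unchanged.
    def normalize_lora_paths(
        lora_paths: Optional[Union[Dict[str, str], List[str]]],
    ) -> Dict[str, str]:
        if lora_paths is None:
            return {}
        if isinstance(lora_paths, dict):
            return dict(lora_paths)
        return {path: path for path in lora_paths}

    print(normalize_lora_paths(["a", "b"]))                    # {'a': 'a', 'b': 'b'}
    print(normalize_lora_paths({"chat": "/models/chat-lora"})) # {'chat': '/models/chat-lora'}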
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -212,11 +217,13 @@ class ServerArgs:
     hicache_ratio: float = 2.0
     hicache_size: int = 0
     hicache_write_policy: str = "write_through_selective"
+    hicache_io_backend: str = ""
     flashinfer_mla_disable_ragged: bool = False
     disable_shared_experts_fusion: bool = False
     disable_chunked_prefix_cache: bool = False
     disable_fast_image_processor: bool = False
     enable_return_hidden_states: bool = False
+    enable_triton_kernel_moe: bool = False
     warmups: Optional[str] = None
 
     # Debug tensor dumps
@@ -315,6 +322,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
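The new block above derates the static memory fraction when the model turns out to be multimodal. A standalone sketch of the arithmetic, assuming the 0.88 default chosen by the surrounding branch:

    # Sketch of the derating above, assuming the 0.88 base fraction
    # that the surrounding code picks for this branch.
    mem_fraction_static = 0.88
    is_multimodal = True  # e.g. a vision-language model with an image processor

    if is_multimodal:
        mem_fraction_static *= 0.90

    print(f"{mem_fraction_static:.3f}")  # 0.792 of GPU memory for weights + KV cache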
@@ -376,6 +391,12 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -399,10 +420,6 @@ class ServerArgs:
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -485,12 +502,6 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
-        if self.page_size > 1 and self.speculative_eagle_topk > 1:
-            self.speculative_eagle_topk = 1
-            logger.warning(
-                "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-            )
-
         if (
             self.speculative_eagle_topk == 1
             and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -587,6 +598,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
        )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -601,6 +618,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -686,6 +708,7 @@ class ServerArgs:
                 "w8a8_fp8",
                 "moe_wna16",
                 "qoq",
+                "w4afp8",
             ],
             help="The quantization method.",
         )
@@ -817,6 +840,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
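The nargs="?"/const combination makes --hybrid-kvcache-ratio tri-state: flag absent gives the default (None, feature off), a bare flag gives the const 0.5, and an explicit value is used verbatim. A self-contained argparse sketch of that behavior, with a plain None standing in for ServerArgs.hybrid_kvcache_ratio:

    import argparse

    # Tri-state behavior of nargs="?" with const, as used by the new flag.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--hybrid-kvcache-ratio",
        nargs="?",
        const=0.5,
        type=float,
        default=None,  # stand-in for ServerArgs.hybrid_kvcache_ratio
    )

    print(parser.parse_args([]).hybrid_kvcache_ratio)                                  # None
    print(parser.parse_args(["--hybrid-kvcache-ratio"]).hybrid_kvcache_ratio)          # 0.5
    print(parser.parse_args(["--hybrid-kvcache-ratio", "0.8"]).hybrid_kvcache_ratio)   # 0.8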
@@ -920,8 +955,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output. 2. Log every input/output.",
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1092,6 +1133,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1186,6 +1228,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1485,6 +1532,13 @@ class ServerArgs:
             default=ServerArgs.hicache_write_policy,
             help="The write policy of hierarchical cache.",
         )
+        parser.add_argument(
+            "--hicache-io-backend",
+            type=str,
+            choices=["direct", "kernel"],
+            default=ServerArgs.hicache_io_backend,
+            help="The IO backend for KV cache transfer between CPU and GPU",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -1510,6 +1564,11 @@ class ServerArgs:
             action="store_true",
             help="Enable returning hidden states with responses.",
         )
+        parser.add_argument(
+            "--enable-triton-kernel-moe",
+            action="store_true",
+            help="Use triton moe grouped gemm kernel.",
+        )
         parser.add_argument(
             "--warmups",
             type=str,
@@ -1706,14 +1765,17 @@ class PortArgs:
 
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
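The three TreeMaskMode layouts trade mask fidelity for memory. A rough sketch of the element counts implied by the allocation code in the following hunks, where bs, num_verify_tokens, and seq_lens_sum mirror the function's parameters:

    from enum import IntEnum

    class TreeMaskMode(IntEnum):
        FULL_MASK = 0
        QLEN_ONLY = 1
        QLEN_ONLY_BITPACKING = 2

    # Element counts for each layout, mirroring the allocations below
    # (bs = batch size, n = num_verify_tokens, s = seq_lens_sum).
    def tree_mask_numel(mode: TreeMaskMode, bs: int, n: int, s: int) -> int:
        if mode == TreeMaskMode.FULL_MASK:
            return s * n + n * n * bs  # one bool per (draft token, attended token)
        if mode == TreeMaskMode.QLEN_ONLY:
            return n * bs * n          # draft-vs-draft block only
        if mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
            return n * bs              # one packed integer per draft token
        raise NotImplementedError(mode)

    # e.g. bs=4, 8 draft tokens, 4096 total context tokens:
    for mode in TreeMaskMode:
        print(mode.name, tree_mask_numel(mode, bs=4, n=8, s=4096))
    # FULL_MASK 33024, QLEN_ONLY 256, QLEN_ONLY_BITPACKING 32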
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
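A worked example of the bit-packing dtype selection above: each draft token's row of up to num_verify_tokens mask bits is packed into a single integer, so the dtype must hold ceil(num_verify_tokens / 8) bytes, which caps this mode at 32 verify tokens for the three dtypes listed:

    import math

    # Reproduces the index arithmetic from the hunk above, with dtype
    # names standing in for the torch dtypes.
    packed_dtypes = ["uint8", "uint16", "uint32"]

    for num_verify_tokens in (4, 8, 16, 24, 32):
        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
        print(num_verify_tokens, packed_dtypes[packed_dtype_idx])
    # 4 uint8, 8 uint8, 16 uint16, 24 uint32, 32 uint32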
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-
-
-
-
-
-
-
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
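The one-line fix above matches the invariant that server_args.py checks earlier in this diff: with speculative_eagle_topk == 1, each speculative step drafts one token, and the verified root token adds one more. A toy check of that accounting:

    # Toy check of the draft-token accounting behind the fix above:
    # with topk == 1, `num_steps` draft passes contribute one token each,
    # plus the verified root token that seeds the tree.
    def expected_num_draft_tokens(num_steps: int) -> int:
        return num_steps + 1

    for num_steps in (1, 3, 5):
        print(num_steps, "steps ->", expected_num_draft_tokens(num_steps), "draft tokens")
    # 1 steps -> 2 draft tokens, 3 steps -> 4, 5 steps -> 6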