sglang 0.4.8.post1__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +48 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +34 -0
- sglang/srt/disaggregation/decode.py +21 -5
- sglang/srt/disaggregation/nixl/conn.py +6 -6
- sglang/srt/disaggregation/prefill.py +2 -2
- sglang/srt/disaggregation/utils.py +1 -1
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +40 -6
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +1 -5
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +2 -2
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +32 -9
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +26 -0
- sglang/srt/layers/linear.py +84 -14
- sglang/srt/layers/logits_processor.py +4 -4
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +36 -13
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +23 -17
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -2
- sglang/srt/layers/moe/fused_moe_triton/layer.py +76 -16
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +10 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +1 -2
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +2 -2
- sglang/srt/layers/vocab_parallel_embedding.py +11 -7
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +50 -13
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +77 -84
- sglang/srt/managers/scheduler.py +113 -59
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +2 -1
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +297 -56
- sglang/srt/model_loader/loader.py +41 -0
- sglang/srt/model_loader/weight_utils.py +72 -4
- sglang/srt/models/deepseek_nextn.py +1 -3
- sglang/srt/models/deepseek_v2.py +181 -45
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_causal.py +4 -3
- sglang/srt/models/gemma3n_mm.py +4 -20
- sglang/srt/models/hunyuan.py +1 -1
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +140 -158
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3n.py +5 -20
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +69 -22
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +200 -27
- sglang/srt/utils.py +306 -146
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/RECORD +140 -133
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.post1.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py CHANGED
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
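In 0.4.9, `merge_bias_tensor` moves out of `sglang.srt.utils` and into `sampling_batch_info.py` itself (note the dropped import above). A minimal sketch of the interesting case, where only one side of the merge carries a bias; the tensor values below are made up for illustration:

```python
import torch

# Hypothetical: a 2-request batch with a per-request bias merged with a
# 3-request batch that has none; the missing rows are filled with `default`.
lhs = torch.tensor([[0.5, -1.0], [0.0, 2.0]])
merged = merge_bias_tensor(lhs, None, bs1=2, bs2=3, device="cpu", default=0.0)
print(merged.shape)  # torch.Size([5, 2]); rows 2..4 are all 0.0
```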
sglang/srt/server_args.py CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,6 +46,7 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
     model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
@@ -61,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -98,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -129,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
     enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -315,6 +320,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
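For example, on a GPU where the default resolves to `mem_fraction_static = 0.88`, a multimodal model now reserves 0.88 × 0.90 ≈ 0.79 of GPU memory for static allocation, leaving the remainder for the image processor.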
@@ -376,6 +389,12 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -399,10 +418,6 @@ class ServerArgs:
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -485,12 +500,6 @@ class ServerArgs:
                 self.speculative_num_draft_tokens,
             ) = auto_choose_speculative_params(self)
 
-            if self.page_size > 1 and self.speculative_eagle_topk > 1:
-                self.speculative_eagle_topk = 1
-                logger.warning(
-                    "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-                )
-
             if (
                 self.speculative_eagle_topk == 1
                 and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -587,6 +596,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
         )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -601,6 +616,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -817,6 +837,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
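The help text above fixes only the two endpoints of the mix. A hypothetical linear interpolation between them (an illustration, not code from this diff) would read:

```python
def swa_to_full_ratio(mix: float, local_attention_size: int, context_length: int) -> float:
    """Hypothetical linear reading of the --hybrid-kvcache-ratio endpoints."""
    assert 0.0 <= mix <= 1.0
    pure_uniform = 1.0                                   # mix = 0.0: swa_size == full_size
    pure_hybrid = local_attention_size / context_length  # mix = 1.0
    return (1 - mix) * pure_uniform + mix * pure_hybrid

# e.g. a 4096-token sliding window in a 131072-token context at the flag's const=0.5:
print(swa_to_full_ratio(0.5, 4096, 131072))  # 0.515625
```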
@@ -920,8 +952,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output.
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1092,6 +1130,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1186,6 +1225,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1706,14 +1750,17 @@ class PortArgs:
 
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
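Taken together, the new flags let a deployment pin the NCCL setup port instead of relying on the random probe above, and skip the warmup request at startup, e.g. `python -m sglang.launch_server --model-path <model> --nccl-port 31000 --skip-server-warmup` (the model path and port value here are placeholders).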
sglang/srt/speculative/build_eagle_tree.py CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
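A quick standalone check of the bitpacking dtype selection above: `(num_verify_tokens + 7) // 8` is the number of mask bytes per draft token, and the `ceil(log2(...))` rounds that up to the nearest power-of-two-wide unsigned dtype.

```python
import math

import torch

packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
for num_verify_tokens in (8, 16, 32):
    # bytes needed per packed mask row, rounded up to a power-of-two-wide dtype
    packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
    print(num_verify_tokens, packed_dtypes[packed_dtype_idx])
# 8 -> torch.uint8, 16 -> torch.uint16, 32 -> torch.uint32
```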
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-    [7 removed lines; their content was not rendered in this diff view]
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
sglang/srt/speculative/eagle_worker.py CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner
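For instance, with `--speculative-num-steps 5`, the draft-extend batch now sets `spec_num_draft_tokens` to 6, matching the `speculative_num_draft_tokens == speculative_num_steps + 1` relationship that `ServerArgs` checks when `speculative_eagle_topk == 1` (see the server_args.py hunk above).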
|