sglang 0.4.8__py3-none-any.whl → 0.4.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +17 -2
- sglang/bench_serving.py +168 -22
- sglang/srt/configs/internvl.py +4 -2
- sglang/srt/configs/janus_pro.py +1 -1
- sglang/srt/configs/model_config.py +49 -0
- sglang/srt/configs/update_config.py +119 -0
- sglang/srt/conversation.py +35 -0
- sglang/srt/custom_op.py +7 -1
- sglang/srt/disaggregation/base/conn.py +2 -0
- sglang/srt/disaggregation/decode.py +22 -6
- sglang/srt/disaggregation/mooncake/conn.py +289 -48
- sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
- sglang/srt/disaggregation/nixl/conn.py +100 -52
- sglang/srt/disaggregation/prefill.py +5 -4
- sglang/srt/disaggregation/utils.py +13 -12
- sglang/srt/distributed/parallel_state.py +44 -17
- sglang/srt/entrypoints/EngineBase.py +8 -0
- sglang/srt/entrypoints/engine.py +45 -9
- sglang/srt/entrypoints/http_server.py +111 -24
- sglang/srt/entrypoints/openai/protocol.py +51 -6
- sglang/srt/entrypoints/openai/serving_chat.py +52 -76
- sglang/srt/entrypoints/openai/serving_completions.py +1 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/eplb/__init__.py +0 -0
- sglang/srt/{managers → eplb}/eplb_algorithms/__init__.py +1 -1
- sglang/srt/{managers → eplb}/eplb_manager.py +2 -4
- sglang/srt/{eplb_simulator → eplb/eplb_simulator}/reader.py +1 -1
- sglang/srt/{managers → eplb}/expert_distribution.py +18 -1
- sglang/srt/{managers → eplb}/expert_location.py +1 -1
- sglang/srt/{managers → eplb}/expert_location_dispatch.py +1 -1
- sglang/srt/{model_executor → eplb}/expert_location_updater.py +17 -1
- sglang/srt/hf_transformers_utils.py +2 -1
- sglang/srt/layers/activation.py +7 -0
- sglang/srt/layers/amx_utils.py +86 -0
- sglang/srt/layers/attention/ascend_backend.py +219 -0
- sglang/srt/layers/attention/flashattention_backend.py +56 -23
- sglang/srt/layers/attention/tbo_backend.py +37 -9
- sglang/srt/layers/communicator.py +18 -2
- sglang/srt/layers/dp_attention.py +9 -3
- sglang/srt/layers/elementwise.py +76 -12
- sglang/srt/layers/flashinfer_comm_fusion.py +202 -0
- sglang/srt/layers/layernorm.py +41 -0
- sglang/srt/layers/linear.py +99 -12
- sglang/srt/layers/logits_processor.py +15 -6
- sglang/srt/layers/moe/ep_moe/kernels.py +23 -8
- sglang/srt/layers/moe/ep_moe/layer.py +115 -25
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +42 -19
- sglang/srt/layers/moe/fused_moe_native.py +7 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +8 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +129 -10
- sglang/srt/layers/moe/router.py +60 -22
- sglang/srt/layers/moe/topk.py +36 -28
- sglang/srt/layers/parameter.py +67 -7
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +1 -1
- sglang/srt/layers/quantization/fp8.py +44 -0
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +6 -6
- sglang/srt/layers/quantization/gptq.py +5 -1
- sglang/srt/layers/quantization/moe_wna16.py +1 -1
- sglang/srt/layers/quantization/quant_utils.py +166 -0
- sglang/srt/layers/quantization/w8a8_int8.py +52 -1
- sglang/srt/layers/rotary_embedding.py +105 -13
- sglang/srt/layers/vocab_parallel_embedding.py +19 -2
- sglang/srt/lora/lora.py +4 -5
- sglang/srt/lora/lora_manager.py +73 -20
- sglang/srt/managers/configure_logging.py +1 -1
- sglang/srt/managers/io_struct.py +60 -15
- sglang/srt/managers/mm_utils.py +73 -59
- sglang/srt/managers/multimodal_processor.py +2 -6
- sglang/srt/managers/multimodal_processors/qwen_audio.py +94 -0
- sglang/srt/managers/schedule_batch.py +80 -79
- sglang/srt/managers/scheduler.py +153 -63
- sglang/srt/managers/scheduler_output_processor_mixin.py +8 -2
- sglang/srt/managers/session_controller.py +12 -3
- sglang/srt/managers/tokenizer_manager.py +314 -103
- sglang/srt/managers/tp_worker.py +13 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -0
- sglang/srt/mem_cache/allocator.py +290 -0
- sglang/srt/mem_cache/chunk_cache.py +34 -2
- sglang/srt/mem_cache/memory_pool.py +289 -3
- sglang/srt/mem_cache/multimodal_cache.py +3 -0
- sglang/srt/model_executor/cuda_graph_runner.py +3 -2
- sglang/srt/model_executor/forward_batch_info.py +17 -4
- sglang/srt/model_executor/model_runner.py +302 -58
- sglang/srt/model_loader/loader.py +86 -10
- sglang/srt/model_loader/weight_utils.py +160 -3
- sglang/srt/models/deepseek_nextn.py +5 -4
- sglang/srt/models/deepseek_v2.py +305 -26
- sglang/srt/models/deepseek_vl2.py +3 -5
- sglang/srt/models/gemma3_causal.py +1 -2
- sglang/srt/models/gemma3n_audio.py +949 -0
- sglang/srt/models/gemma3n_causal.py +1010 -0
- sglang/srt/models/gemma3n_mm.py +495 -0
- sglang/srt/models/hunyuan.py +771 -0
- sglang/srt/models/kimi_vl.py +1 -2
- sglang/srt/models/llama.py +10 -4
- sglang/srt/models/llama4.py +32 -45
- sglang/srt/models/llama_eagle3.py +61 -11
- sglang/srt/models/llava.py +5 -5
- sglang/srt/models/minicpmo.py +2 -2
- sglang/srt/models/mistral.py +1 -1
- sglang/srt/models/mllama4.py +43 -11
- sglang/srt/models/phi4mm.py +1 -3
- sglang/srt/models/pixtral.py +3 -7
- sglang/srt/models/qwen2.py +31 -3
- sglang/srt/models/qwen2_5_vl.py +1 -3
- sglang/srt/models/qwen2_audio.py +200 -0
- sglang/srt/models/qwen2_moe.py +32 -6
- sglang/srt/models/qwen2_vl.py +1 -4
- sglang/srt/models/qwen3.py +94 -25
- sglang/srt/models/qwen3_moe.py +68 -21
- sglang/srt/models/vila.py +3 -8
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/base_processor.py +150 -133
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/clip.py +2 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/deepseek_vl_v2.py +4 -11
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/gemma3.py +3 -10
- sglang/srt/multimodal/processors/gemma3n.py +82 -0
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/internvl.py +3 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/janus_pro.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/kimi_vl.py +6 -13
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/llava.py +2 -10
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/minicpm.py +5 -12
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mlama.py +2 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/mllama4.py +3 -6
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/phi4mm.py +4 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/pixtral.py +3 -9
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/qwen_vl.py +8 -14
- sglang/srt/{managers/multimodal_processors → multimodal/processors}/vila.py +13 -31
- sglang/srt/operations_strategy.py +6 -2
- sglang/srt/reasoning_parser.py +26 -0
- sglang/srt/sampling/sampling_batch_info.py +39 -1
- sglang/srt/server_args.py +85 -24
- sglang/srt/speculative/build_eagle_tree.py +57 -18
- sglang/srt/speculative/eagle_worker.py +6 -4
- sglang/srt/two_batch_overlap.py +204 -28
- sglang/srt/utils.py +369 -138
- sglang/srt/warmup.py +12 -3
- sglang/test/runners.py +10 -1
- sglang/test/test_utils.py +15 -3
- sglang/version.py +1 -1
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/METADATA +9 -6
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/RECORD +149 -137
- sglang/math_utils.py +0 -8
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek.py +0 -0
- /sglang/srt/{managers → eplb}/eplb_algorithms/deepseek_vec.py +0 -0
- /sglang/srt/{eplb_simulator → eplb/eplb_simulator}/__init__.py +0 -0
- /sglang/srt/{mm_utils.py → multimodal/mm_utils.py} +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/WHEEL +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.8.dist-info → sglang-0.4.9.dist-info}/top_level.txt +0 -0
sglang/srt/sampling/sampling_batch_info.py
CHANGED
@@ -10,7 +10,6 @@ import torch
 import sglang.srt.sampling.penaltylib as penaltylib
 from sglang.srt.sampling.custom_logit_processor import CustomLogitProcessor
 from sglang.srt.sampling.sampling_params import TOP_K_ALL
-from sglang.srt.utils import merge_bias_tensor
 
 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import ScheduleBatch
@@ -345,3 +344,42 @@ class SamplingBatchInfo:
         self.logit_bias = merge_bias_tensor(
             self.logit_bias, other.logit_bias, len(self), len(other), self.device, 0.0
         )
+
+
+def merge_bias_tensor(
+    lhs: Optional[torch.Tensor],
+    rhs: Optional[torch.Tensor],
+    bs1: int,
+    bs2: int,
+    device: str,
+    default: float,
+):
+    """Merge two bias tensors for batch merging.
+
+    Args:
+        lhs: Left-hand side tensor
+        rhs: Right-hand side tensor
+        bs1: Batch size of left-hand side tensor
+        bs2: Batch size of right-hand side tensor
+        device: Device to place the merged tensor on
+        default: Default value for missing tensor elements
+
+    Returns:
+        Merged tensor or None if both inputs are None
+    """
+    if lhs is None and rhs is None:
+        return None
+
+    if lhs is not None and rhs is not None:
+        return torch.cat([lhs, rhs])
+    else:
+        if lhs is not None:
+            shape, dtype = lhs.shape[1:], lhs.dtype
+        else:
+            shape, dtype = rhs.shape[1:], rhs.dtype
+
+        if lhs is None:
+            lhs = torch.empty((bs1, *shape), device=device, dtype=dtype).fill_(default)
+        if rhs is None:
+            rhs = torch.empty((bs2, *shape), device=device, dtype=dtype).fill_(default)
+        return torch.cat([lhs, rhs])
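The relocated merge_bias_tensor helper concatenates per-request logit-bias tensors along the batch dimension, synthesizing a default-filled tensor for whichever side is missing. A minimal usage sketch (the batch sizes, vocab size, and bias values below are made up for illustration; the import path reflects the function's new home in this diff):

import torch

from sglang.srt.sampling.sampling_batch_info import merge_bias_tensor

bs1, bs2, vocab = 2, 3, 8             # hypothetical batch sizes / vocab size
lhs = torch.full((bs1, vocab), -1.0)  # first batch carries a logit bias
rhs = None                            # second batch has none

# The missing side is materialized with the default value (0.0 here),
# then both sides are concatenated along the batch dimension.
merged = merge_bias_tensor(lhs, rhs, bs1, bs2, device="cpu", default=0.0)
assert merged.shape == (bs1 + bs2, vocab)
assert (merged[bs1:] == 0.0).all()    # rows for the missing side take the default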
sglang/srt/server_args.py
CHANGED
@@ -20,7 +20,7 @@ import logging
 import os
 import random
 import tempfile
-from typing import List, Literal, Optional
+from typing import List, Literal, Optional, Union
 
 from sglang.srt.hf_transformers_utils import check_gguf_file, get_config
 from sglang.srt.reasoning_parser import ReasoningParser
@@ -46,7 +46,9 @@ class ServerArgs:
     tokenizer_path: Optional[str] = None
     tokenizer_mode: str = "auto"
     skip_tokenizer_init: bool = False
+    skip_server_warmup: bool = False
     load_format: str = "auto"
+    model_loader_extra_config: str = "{}"
     trust_remote_code: bool = False
     dtype: str = "auto"
     kv_cache_dtype: str = "auto"
@@ -60,11 +62,13 @@ class ServerArgs:
     is_embedding: bool = False
     enable_multimodal: Optional[bool] = None
     revision: Optional[str] = None
+    hybrid_kvcache_ratio: Optional[float] = None
     impl: str = "auto"
 
     # Port for the HTTP server
     host: str = "127.0.0.1"
     port: int = 30000
+    nccl_port: Optional[int] = None
 
     # Memory and scheduling
     mem_fraction_static: Optional[float] = None
@@ -97,6 +101,7 @@ class ServerArgs:
     log_level_http: Optional[str] = None
     log_requests: bool = False
     log_requests_level: int = 0
+    crash_dump_folder: Optional[str] = None
     show_time_cost: bool = False
     enable_metrics: bool = False
     bucket_time_to_first_token: Optional[List[float]] = None
@@ -128,7 +133,7 @@ class ServerArgs:
     preferred_sampling_params: Optional[str] = None
 
     # LoRA
-    lora_paths: Optional[List[str]] = None
+    lora_paths: Optional[Union[dict[str, str], List[str]]] = None
     max_loras_per_batch: int = 8
     lora_backend: str = "triton"
 
@@ -153,6 +158,7 @@ class ServerArgs:
     enable_ep_moe: bool = False
     enable_deepep_moe: bool = False
    enable_flashinfer_moe: bool = False
+    enable_flashinfer_allreduce_fusion: bool = False
     deepep_mode: Optional[Literal["auto", "normal", "low_latency"]] = "auto"
     ep_num_redundant_experts: int = 0
     ep_dispatch_algorithm: Optional[Literal["static", "dynamic", "fake"]] = None
@@ -314,6 +320,14 @@ class ServerArgs:
         else:
             self.mem_fraction_static = 0.88
 
+        # Lazy init to avoid circular import
+        from sglang.srt.configs.model_config import ModelConfig
+
+        # Multimodal models need more memory for the image processor
+        model_config = ModelConfig.from_server_args(self)
+        if model_config.is_multimodal:
+            self.mem_fraction_static *= 0.90
+
         # Set chunked prefill size, which depends on the gpu memory capacity
         if self.chunked_prefill_size is None:
             if gpu_mem is not None:
@@ -375,6 +389,12 @@ class ServerArgs:
             )
             self.disable_cuda_graph = True
 
+        if self.attention_backend == "ascend":
+            logger.warning(
+                "At this moment Ascend attention backend only supports a page_size of 128, change page_size to 128."
+            )
+            self.page_size = 128
+
         # Choose grammar backend
         if self.grammar_backend is None:
             self.grammar_backend = "xgrammar"
@@ -398,10 +418,6 @@ class ServerArgs:
 
         # DeepEP MoE
         if self.enable_deepep_moe:
-            if self.deepep_mode == "auto":
-                assert (
-                    not self.enable_dp_attention
-                ), "DeepEP MoE `auto` mode is not supported with DP Attention."
             if self.deepep_mode == "normal":
                 logger.warning("Cuda graph is disabled because deepep_mode=`normal`")
                 self.disable_cuda_graph = True
@@ -484,12 +500,6 @@ class ServerArgs:
             self.speculative_num_draft_tokens,
         ) = auto_choose_speculative_params(self)
 
-        if self.page_size > 1 and self.speculative_eagle_topk > 1:
-            self.speculative_eagle_topk = 1
-            logger.warning(
-                "speculative_eagle_topk is adjusted to 1 when page_size > 1"
-            )
-
         if (
            self.speculative_eagle_topk == 1
             and self.speculative_num_draft_tokens != self.speculative_num_steps + 1
@@ -563,6 +573,7 @@ class ServerArgs:
         # Model and port args
         parser.add_argument(
             "--model-path",
+            "--model",
             type=str,
             help="The path of the model weights. This can be a local folder or a Hugging Face repo ID.",
             required=True,
@@ -585,6 +596,12 @@ class ServerArgs:
             default=ServerArgs.port,
             help="The port of the HTTP server.",
         )
+        parser.add_argument(
+            "--nccl-port",
+            type=int,
+            default=ServerArgs.nccl_port,
+            help="The port for NCCL distributed environment setup. Defaults to a random port.",
+        )
         parser.add_argument(
             "--tokenizer-mode",
             type=str,
@@ -599,6 +616,11 @@ class ServerArgs:
             action="store_true",
             help="If set, skip init tokenizer and pass input_ids in generate request.",
         )
+        parser.add_argument(
+            "--skip-server-warmup",
+            action="store_true",
+            help="If set, skip warmup.",
+        )
         parser.add_argument(
             "--load-format",
             type=str,
@@ -632,6 +654,13 @@ class ServerArgs:
             "layer before loading another to make the peak memory envelope "
             "smaller.",
         )
+        parser.add_argument(
+            "--model-loader-extra-config",
+            type=str,
+            help="Extra config for model loader. "
+            "This will be passed to the model loader corresponding to the chosen load_format.",
+            default=ServerArgs.model_loader_extra_config,
+        )
         parser.add_argument(
             "--trust-remote-code",
             action="store_true",
@@ -808,6 +837,18 @@ class ServerArgs:
             default=ServerArgs.page_size,
             help="The number of tokens in a page.",
         )
+        parser.add_argument(
+            "--hybrid-kvcache-ratio",
+            nargs="?",
+            const=0.5,
+            type=float,
+            default=ServerArgs.hybrid_kvcache_ratio,
+            help=(
+                "Mix ratio in [0,1] between uniform and hybrid kv buffers "
+                "(0.0 = pure uniform: swa_size / full_size = 1)"
+                "(1.0 = pure hybrid: swa_size / full_size = local_attention_size / context_length)"
+            ),
+        )
 
         # Other runtime options
         parser.add_argument(
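The --hybrid-kvcache-ratio help text pins down two endpoints of a linear mix between uniform and hybrid (sliding-window) KV buffer sizing. A sketch of the implied interpolation, for illustration only (the actual sizing logic lives in the memory-pool and allocator changes listed above, which this page does not show):

def swa_to_full_ratio(
    hybrid_kvcache_ratio: float,
    local_attention_size: int,
    context_length: int,
) -> float:
    # Endpoint at 0.0: swa_size / full_size = 1 (pure uniform buffers).
    # Endpoint at 1.0: swa_size / full_size = local_attention_size / context_length.
    hybrid_endpoint = local_attention_size / context_length
    return (1.0 - hybrid_kvcache_ratio) * 1.0 + hybrid_kvcache_ratio * hybrid_endpoint

# E.g. a 4k local-attention window over a 128k context at the default
# const=0.5 sizes the SWA buffer at ~51.6% of the full buffer.
print(swa_to_full_ratio(0.5, 4096, 131072))  # 0.515625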
@@ -911,8 +952,14 @@ class ServerArgs:
             "--log-requests-level",
             type=int,
             default=0,
-            help="0: Log metadata. 1. Log metadata and partial input/output.
-            choices=[0, 1, 2],
+            help="0: Log metadata (no sampling parameters). 1: Log metadata and sampling parameters. 2: Log metadata, sampling parameters and partial input/output. 3: Log every input/output.",
+            choices=[0, 1, 2, 3],
+        )
+        parser.add_argument(
+            "--crash-dump-folder",
+            type=str,
+            default=ServerArgs.crash_dump_folder,
+            help="Folder path to dump requests from the last 5 min before a crash (if any). If not specified, crash dumping is disabled.",
         )
         parser.add_argument(
             "--show-time-cost",
@@ -1083,6 +1130,7 @@ class ServerArgs:
                 "flashmla",
                 "intel_amx",
                 "torch_native",
+                "ascend",
                 "triton",
             ],
             default=ServerArgs.attention_backend,
@@ -1177,6 +1225,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer CUTLASS MoE backend for modelopt_fp4 quant on Blackwell. Supports MoE-EP with --enable-ep-moe",
         )
+        parser.add_argument(
+            "--enable-flashinfer-allreduce-fusion",
+            action="store_true",
+            help="Enable FlashInfer allreduce fusion for Add_RMSNorm.",
+        )
         parser.add_argument(
             "--enable-deepep-moe",
             action="store_true",
@@ -1692,16 +1745,22 @@ class PortArgs:
     # The ipc filename for rpc call between Engine and Scheduler
     rpc_ipc_name: str
 
+    # The ipc filename for Scheduler to send metrics
+    metrics_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
-        port = server_args.port + random.randint(100, 1000)
-        while True:
-            if is_port_available(port):
-                break
-            if port < 60000:
-                port += 42
-            else:
-                port -= 43
+        if server_args.nccl_port is None:
+            port = server_args.port + random.randint(100, 1000)
+            while True:
+                if is_port_available(port):
+                    break
+                if port < 60000:
+                    port += 42
+                else:
+                    port -= 43
+        else:
+            port = server_args.nccl_port
 
         if not server_args.enable_dp_attention:
             # Normal case, use IPC within a single node
@@ -1711,6 +1770,7 @@ class PortArgs:
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
                 rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
+                metrics_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
            # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1730,9 +1790,9 @@ class PortArgs:
             port_base = int(dist_init_port) + 1
             if dp_rank is None:
                 # TokenizerManager to DataParallelController
-                scheduler_input_port = port_base + 3
+                scheduler_input_port = port_base + 4
             else:
-                scheduler_input_port = port_base + 3 + 1 + dp_rank
+                scheduler_input_port = port_base + 4 + 1 + dp_rank
 
             return PortArgs(
                 tokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base}",
@@ -1740,6 +1800,7 @@ class PortArgs:
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
                 rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
+                metrics_ipc_name=f"tcp://{dist_init_host}:{port_base + 3}",
             )
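When --nccl-port is not given, PortArgs.init_new keeps the old behavior of probing for a free port starting at a random offset above the HTTP port. A standalone sketch of that probing loop; sglang's own is_port_available is not shown in this diff, so a socket-based stand-in is used here:

import random
import socket


def is_port_available(port: int) -> bool:
    """Hypothetical stand-in for sglang's helper of the same name."""
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        try:
            s.bind(("127.0.0.1", port))
            return True
        except OSError:
            return False


def pick_nccl_port(http_port: int = 30000) -> int:
    # Mirrors the loop in PortArgs.init_new: start at a random offset above
    # the HTTP port, step up by 42 until 60000, then step down by 43.
    port = http_port + random.randint(100, 1000)
    while True:
        if is_port_available(port):
            return port
        if port < 60000:
            port += 42
        else:
            port -= 43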
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -1,10 +1,12 @@
 # NOTE: Please run this file to make sure the test cases are correct.
 
-from typing import List
+import math
+from enum import IntEnum
+from typing import List, Optional
 
 import torch
 
-from sglang.srt.utils import is_cuda, is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
 if is_cuda() or is_hip():
     from sgl_kernel import (
@@ -40,6 +42,12 @@ def build_tree_kernel_efficient_preprocess(
     return parent_list, top_scores_index, draft_tokens
 
 
+class TreeMaskMode(IntEnum):
+    FULL_MASK = 0
+    QLEN_ONLY = 1
+    QLEN_ONLY_BITPACKING = 2
+
+
 def build_tree_kernel_efficient(
     verified_id: torch.Tensor,
     score_list: List[torch.Tensor],
@@ -50,6 +58,9 @@ def build_tree_kernel_efficient(
     topk: int,
     spec_steps: int,
     num_verify_tokens: int,
+    tree_mask_mode: TreeMaskMode = TreeMaskMode.FULL_MASK,
+    tree_mask_buf: Optional[torch.Tensor] = None,
+    position_buf: Optional[torch.Tensor] = None,
 ):
     parent_list, top_scores_index, draft_tokens = (
         build_tree_kernel_efficient_preprocess(
@@ -66,15 +77,37 @@ def build_tree_kernel_efficient(
     device = seq_lens.device
     # e.g. for bs=1, tree_mask: num_draft_token, seq_lens_sum + num_draft_token (flattened)
     # where each row indicates the attending pattern of each draft token
+    # if use_partial_packed_tree_mask is True, tree_mask: num_draft_token (flattened, packed)
+    if tree_mask_buf is not None:
+        tree_mask = tree_mask_buf
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY:
+        tree_mask = torch.full(
+            (num_verify_tokens * bs * num_verify_tokens,),
+            True,
+            dtype=torch.bool,
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.QLEN_ONLY_BITPACKING:
+        packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]
+        packed_dtype_idx = int(math.ceil(math.log2((num_verify_tokens + 7) // 8)))
+        tree_mask = torch.zeros(
+            (num_verify_tokens * bs,),
+            dtype=packed_dtypes[packed_dtype_idx],
+            device=device,
+        )
+    elif tree_mask_mode == TreeMaskMode.FULL_MASK:
+        tree_mask = torch.full(
+            (
+                seq_lens_sum * num_verify_tokens
+                + num_verify_tokens * num_verify_tokens * bs,
+            ),
+            True,
+            device=device,
+        )
+    else:
+        raise NotImplementedError(f"Invalid tree mask: {tree_mask_mode=}")
+
     # TODO: make them torch.empty and fuse them into `sgl_build_tree_kernel`
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
     retrive_index = torch.full(
         (bs, num_verify_tokens), -1, device=device, dtype=torch.long
     )
@@ -87,7 +120,12 @@ def build_tree_kernel_efficient(
     # position: where each token belongs to
     # e.g. if depth of each draft token is [0, 1, 1, 2] and the prompt length is 7
     # then, positions = [7, 8, 8, 9]
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
+    if position_buf is not None:
+        positions = position_buf
+    else:
+        positions = torch.empty(
+            (bs * num_verify_tokens,), device=device, dtype=torch.long
+        )
 
     sgl_build_tree_kernel_efficient(
         parent_list,
@@ -101,6 +139,7 @@ def build_tree_kernel_efficient(
         topk,
         spec_steps,
         num_verify_tokens,
+        tree_mask_mode,
     )
     return (
         tree_mask,
@@ -344,13 +383,13 @@ def test_build_tree_kernel_efficient():
         num_verify_tokens=num_draft_token,
     )
 
-
-
-
-
-
-
-
+    print("=========== build tree kernel efficient ==========")
+    print(f"{tree_mask=}")
+    print(f"{position=}")
+    print(f"{retrive_index=}")
+    print(f"{retrive_next_token=}")
+    print(f"{retrive_next_sibling=}")
+    print(f"{draft_tokens=}")
     assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
     assert retrive_index.tolist() == [
         [0, 1, 2, 3, 4, 5, 6, 7],
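For QLEN_ONLY_BITPACKING, the dtype is chosen so that one draft token's row of the qlen x qlen mask packs into a single unsigned integer. A worked check of the index expression from the diff (illustrative; mirrors the formula but is not code from the package):

import math
import torch

packed_dtypes = [torch.uint8, torch.uint16, torch.uint32]

def packed_dtype_for(num_verify_tokens: int) -> torch.dtype:
    # (num_verify_tokens + 7) // 8 is the number of bytes needed per mask row;
    # ceil(log2(bytes)) maps 1 byte -> uint8, 2 -> uint16, 3..4 -> uint32.
    num_bytes = (num_verify_tokens + 7) // 8
    return packed_dtypes[int(math.ceil(math.log2(num_bytes)))]

assert packed_dtype_for(8) == torch.uint8    # 8 mask bits fit in one byte
assert packed_dtype_for(16) == torch.uint16
assert packed_dtype_for(32) == torch.uint32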
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -140,9 +140,11 @@ class EAGLEWorker(TpModelWorker):
             self.draft_model_runner.model.set_embed(embed)
 
             # grab hot token ids
-            self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
-                embed.device
-            )
+            if self.draft_model_runner.model.hot_token_id is not None:
+                self.hot_token_id = self.draft_model_runner.model.hot_token_id.to(
+                    embed.device
+                )
+
         else:
             if self.hot_token_id is not None:
                 head = head.clone()
@@ -842,7 +844,7 @@ class EAGLEWorker(TpModelWorker):
         )
         batch.return_hidden_states = False
         model_worker_batch = batch.get_model_worker_batch()
-        model_worker_batch.spec_num_draft_tokens = self.
+        model_worker_batch.spec_num_draft_tokens = self.speculative_num_steps + 1
         assert model_worker_batch.capture_hidden_mode == CaptureHiddenMode.LAST
         forward_batch = ForwardBatch.init_new(
             model_worker_batch, self.draft_model_runner