sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/server_args.py
CHANGED
@@ -16,6 +16,7 @@
 import argparse
 import dataclasses
 import logging
+import os
 import random
 import tempfile
 from typing import List, Optional
@@ -24,12 +25,14 @@ from sglang.srt.hf_transformers_utils import check_gguf_file
 from sglang.srt.reasoning_parser import ReasoningParser
 from sglang.srt.utils import (
     get_amdgpu_memory_capacity,
+    get_device,
     get_hpu_memory_capacity,
     get_nvgpu_memory_capacity,
     is_cuda,
     is_flashinfer_available,
     is_hip,
     is_port_available,
+    is_remote_url,
     is_valid_ipv6_address,
     nullable_str,
 )
@@ -51,9 +54,10 @@ class ServerArgs:
     quantization: Optional[str] = None
     quantization_param_path: nullable_str = None
     context_length: Optional[int] = None
-    device: str = "cuda"
+    device: Optional[str] = None
     served_model_name: Optional[str] = None
     chat_template: Optional[str] = None
+    completion_template: Optional[str] = None
     is_embedding: bool = False
     revision: Optional[str] = None
 
@@ -122,7 +126,7 @@ class ServerArgs:
     # Kernel backend
     attention_backend: Optional[str] = None
     sampling_backend: Optional[str] = None
-    grammar_backend: Optional[str] = "outlines"
+    grammar_backend: Optional[str] = "xgrammar"
 
     # Speculative decoding
     speculative_algorithm: Optional[str] = None
@@ -154,6 +158,7 @@ class ServerArgs:
     enable_mixed_chunk: bool = False
     enable_dp_attention: bool = False
     enable_ep_moe: bool = False
+    enable_deepep_moe: bool = False
     enable_torch_compile: bool = False
     torch_compile_max_bs: int = 32
     cuda_graph_max_bs: Optional[int] = None
@@ -170,7 +175,9 @@ class ServerArgs:
     enable_custom_logit_processor: bool = False
     tool_call_parser: str = None
     enable_hierarchical_cache: bool = False
+    hicache_ratio: float = 2.0
     enable_flashinfer_mla: bool = False
+    enable_flashmla: bool = False
     flashinfer_mla_disable_ragged: bool = False
     warmups: Optional[str] = None
 
@@ -179,11 +186,18 @@ class ServerArgs:
     debug_tensor_dump_input_file: Optional[str] = None
     debug_tensor_dump_inject: bool = False
 
+    # For PD disaggregation: can be "null" (not disaggregated), "prefill" (prefill-only), or "decode" (decode-only)
+    disaggregation_mode: str = "null"
+    disaggregation_bootstrap_port: int = 8998
+
     def __post_init__(self):
         # Set missing default values
         if self.tokenizer_path is None:
             self.tokenizer_path = self.model_path
 
+        if self.device is None:
+            self.device = get_device()
+
         if self.served_model_name is None:
             self.served_model_name = self.model_path
 
@@ -222,6 +236,11 @@ class ServerArgs:
 
         assert self.chunked_prefill_size % self.page_size == 0
 
+        if self.enable_flashmla is True:
+            logger.warning(
+                "FlashMLA only supports a page_size of 64, change page_size to 64."
+            )
+            self.page_size = 64
         # Set cuda graph max batch size
         if self.cuda_graph_max_bs is None:
             # Based on detailed statistics, when serving TP1/TP2 models on lower-end GPUs with HBM<25G, you can either disable cuda graph or set `cuda_graph_max_bs` to a very small value to reduce the memory overhead of creating cuda graphs, with almost no impact on performance. However, when serving models with TP4 or TP8, we need to enable cuda graph to maintain high performance. In this case, we can set `cuda_graph_max_bs` to 80 (half of the default value 160) to reduce the memory overhead of creating cuda graphs. Looking at the logs from TP4 serving of qwen2-72b, a value of 80 is sufficient and can reduce the memory overhead of creating cuda graphs on lower-end GPUs compared to the original 160, avoiding OOM issues.
@@ -262,25 +281,33 @@ class ServerArgs:
 
         # Data parallelism attention
         if self.enable_dp_attention:
-            self.dp_size = self.tp_size
-            assert self.tp_size % self.dp_size == 0
-            self.chunked_prefill_size = self.chunked_prefill_size // 2
             self.schedule_conservativeness = self.schedule_conservativeness * 0.3
+            assert (
+                self.dp_size > 1
+            ), "Please set a dp-size > 1. You can use 1 < dp-size <= tp-size "
+            assert self.tp_size % self.dp_size == 0
+            self.chunked_prefill_size = self.chunked_prefill_size // self.dp_size
             logger.warning(
                 f"DP attention is enabled. The chunked prefill size is adjusted to {self.chunked_prefill_size} to avoid MoE kernel issues. "
-                f"The schedule conservativeness is adjusted to {self.schedule_conservativeness}. "
-                "Data parallel size is adjusted to be the same as tensor parallel size. "
             )
+        # DeepEP MoE
+        if self.enable_deepep_moe:
+            self.ep_size = self.dp_size
+            logger.info(
+                f"DeepEP MoE is enabled. The expert parallel size is adjusted to be the same as the data parallel size[{self.dp_size}]."
+            )
 
         # Speculative Decoding
         if self.speculative_algorithm == "NEXTN":
             # NEXTN shares the same implementation of EAGLE
             self.speculative_algorithm = "EAGLE"
 
-        if self.speculative_algorithm == "EAGLE":
+        if (
+            self.speculative_algorithm == "EAGLE"
+            or self.speculative_algorithm == "EAGLE3"
+        ):
             if self.max_running_requests is None:
                 self.max_running_requests = 32
-            self.disable_cuda_graph_padding = True
             self.disable_overlap_schedule = True
             logger.info(
                 "Overlap scheduler is disabled because of using "
@@ -296,10 +323,29 @@ class ServerArgs:
         ) and check_gguf_file(self.model_path):
             self.quantization = self.load_format = "gguf"
 
+        if is_remote_url(self.model_path):
+            self.load_format = "remote"
+
         # AMD-specific Triton attention KV splits default number
         if is_hip():
             self.triton_attention_num_kv_splits = 16
 
+        # PD disaggregation
+        if self.disaggregation_mode == "prefill":
+            self.disable_cuda_graph = True
+            logger.warning("KV cache is forced as chunk cache for decode server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for prefill server")
+        elif self.disaggregation_mode == "decode":
+            self.disable_radix_cache = True
+            logger.warning("Cuda graph is disabled for prefill server")
+            self.disable_overlap_schedule = True
+            logger.warning("Overlap scheduler is disabled for decode server")
+
+        os.environ["SGLANG_ENABLE_TORCH_COMPILE"] = (
+            "1" if self.enable_torch_compile else "0"
+        )
+
     @staticmethod
     def add_cli_args(parser: argparse.ArgumentParser):
         # Model and port args
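The __post_init__ hunks above tie several of the new options together: enable_flashmla forces a 64-token page size, and the PD-disaggregation modes turn off CUDA graphs, the radix cache, and the overlap scheduler on the corresponding side. A minimal sketch of how that surfaces when constructing ServerArgs directly, assuming sglang 0.4.4.post2 on a CUDA machine; the model path is only a placeholder:

from sglang.srt.server_args import ServerArgs

args = ServerArgs(
    model_path="Qwen/Qwen2.5-7B-Instruct",  # placeholder model
    enable_flashmla=True,                   # new flag; __post_init__ coerces page_size to 64
    disaggregation_mode="decode",           # new PD mode: "null", "prefill", or "decode"
)
assert args.page_size == 64                 # forced by the FlashMLA branch above
assert args.disable_radix_cache             # decode-side PD disaggregation
assert args.disable_overlap_schedule        # overlap scheduler disabled for the decode server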
@@ -345,9 +391,11 @@ class ServerArgs:
                 "safetensors",
                 "npcache",
                 "dummy",
+                "sharded_state",
                 "gguf",
                 "bitsandbytes",
                 "layered",
+                "remote",
             ],
             help="The format of the model weights to load. "
             '"auto" will try to load the weights in the safetensors format '
@@ -429,9 +477,8 @@ class ServerArgs:
         parser.add_argument(
             "--device",
             type=str,
-            default="cuda",
-            choices=["cuda", "xpu", "hpu", "cpu"],
-            help="The device type.",
+            default=ServerArgs.device,
+            help="The device to use ('cuda', 'xpu', 'hpu', 'cpu'). Defaults to auto-detection if not specified.",
         )
         parser.add_argument(
             "--served-model-name",
@@ -445,6 +492,12 @@ class ServerArgs:
             default=ServerArgs.chat_template,
             help="The buliltin chat template name or the path of the chat template file. This is only used for OpenAI-compatible API server.",
         )
+        parser.add_argument(
+            "--completion-template",
+            type=str,
+            default=ServerArgs.completion_template,
+            help="The buliltin completion template name or the path of the completion template file. This is only used for OpenAI-compatible API server. only for code completion currently.",
+        )
         parser.add_argument(
             "--is-embedding",
             action="store_true",
@@ -722,7 +775,7 @@ class ServerArgs:
         parser.add_argument(
             "--attention-backend",
             type=str,
-            choices=["flashinfer", "triton", "torch_native"],
+            choices=["flashinfer", "triton", "torch_native", "fa3"],
             default=ServerArgs.attention_backend,
             help="Choose the kernels for attention layers.",
         )
@@ -745,6 +798,11 @@ class ServerArgs:
             action="store_true",
             help="Enable FlashInfer MLA optimization",
         )
+        parser.add_argument(
+            "--enable-flashmla",
+            action="store_true",
+            help="Enable FlashMLA decode optimization",
+        )
         parser.add_argument(
             "--flashinfer-mla-disable-ragged",
             action="store_true",
@@ -755,7 +813,7 @@ class ServerArgs:
         parser.add_argument(
             "--speculative-algorithm",
             type=str,
-            choices=["EAGLE", "NEXTN"],
+            choices=["EAGLE", "EAGLE3", "NEXTN"],
             help="Speculative algorithm.",
         )
         parser.add_argument(
@@ -984,6 +1042,18 @@ class ServerArgs:
             action="store_true",
             help="Enable hierarchical cache",
         )
+        parser.add_argument(
+            "--hicache-ratio",
+            type=float,
+            required=False,
+            default=ServerArgs.hicache_ratio,
+            help="The ratio of the size of host KV cache memory pool to the size of device pool.",
+        )
+        parser.add_argument(
+            "--enable-deepep-moe",
+            action="store_true",
+            help="Enabling DeepEP MoE implementation for EP MoE.",
+        )
 
         # Server warmups
         parser.add_argument(
@@ -1014,6 +1084,21 @@ class ServerArgs:
             help="Inject the outputs from jax as the input of every layer.",
         )
 
+        # Disaggregation
+        parser.add_argument(
+            "--disaggregation-mode",
+            type=str,
+            default="null",
+            choices=["null", "prefill", "decode"],
+            help='Only used for PD disaggregation. "prefill" for prefill-only server, and "decode" for decode-only server. If not specified, it is not PD disaggregated',
+        )
+        parser.add_argument(
+            "--disaggregation-bootstrap-port",
+            type=int,
+            default=ServerArgs.disaggregation_bootstrap_port,
+            help="Bootstrap server port on the prefill server. Default is 8998.",
+        )
+
     @classmethod
     def from_cli_args(cls, args: argparse.Namespace):
         args.tp_size = args.tensor_parallel_size
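The new CLI flags registered above reach the dataclass through the existing add_cli_args/from_cli_args pair. A short sketch of parsing them, assuming sglang 0.4.4.post2 is installed; the model path is a placeholder:

import argparse

from sglang.srt.server_args import ServerArgs

parser = argparse.ArgumentParser()
ServerArgs.add_cli_args(parser)
raw = parser.parse_args(
    [
        "--model-path", "deepseek-ai/DeepSeek-V2-Lite",  # placeholder model
        "--disaggregation-mode", "prefill",              # new PD-disaggregation flag
        "--disaggregation-bootstrap-port", "8998",       # new bootstrap port (default 8998)
        "--hicache-ratio", "2.0",                        # new host/device KV cache pool ratio
        "--enable-deepep-moe",                           # new DeepEP MoE toggle
    ]
)
server_args = ServerArgs.from_cli_args(raw)  # __post_init__ then applies the prefill-side defaults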
@@ -1088,6 +1173,9 @@ class PortArgs:
     # The port for nccl initialization (torch.dist)
     nccl_port: int
 
+    # The ipc filename for rpc call between Engine and Scheduler
+    rpc_ipc_name: str
+
     @staticmethod
     def init_new(server_args, dp_rank: Optional[int] = None) -> "PortArgs":
         port = server_args.port + random.randint(100, 1000)
@@ -1106,6 +1194,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 detokenizer_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
                 nccl_port=port,
+                rpc_ipc_name=f"ipc://{tempfile.NamedTemporaryFile(delete=False).name}",
             )
         else:
             # DP attention. Use TCP + port to handle both single-node and multi-node.
@@ -1131,6 +1220,7 @@ class PortArgs:
                 scheduler_input_ipc_name=f"tcp://{dist_init_host}:{scheduler_input_port}",
                 detokenizer_ipc_name=f"tcp://{dist_init_host}:{port_base + 1}",
                 nccl_port=port,
+                rpc_ipc_name=f"tcp://{dist_init_host}:{port_base + 2}",
             )
 
 
sglang/srt/speculative/build_eagle_tree.py
CHANGED
@@ -3,8 +3,13 @@
 from typing import List
 
 import torch
-
-from
+
+from sglang.srt.utils import is_cuda_available, is_hip
+
+if is_cuda_available() or is_hip():
+    from sgl_kernel import (
+        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
+    )
 
 
 def build_tree_kernel_efficient_preprocess(
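The new header only pulls the fused kernel in when a CUDA or ROCm build of sgl_kernel is present. A minimal sketch of that guarded-import pattern; the None fallback is an assumption for illustration, not part of the diff:

from sglang.srt.utils import is_cuda_available, is_hip

if is_cuda_available() or is_hip():
    from sgl_kernel import (
        build_tree_kernel_efficient as sgl_build_tree_kernel_efficient,
    )
else:
    # Assumed fallback so the module can still be imported on CPU-only installs.
    sgl_build_tree_kernel_efficient = None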
@@ -23,7 +28,6 @@ def build_tree_kernel_efficient_preprocess(
     top_scores = torch.topk(score_list, num_verify_tokens - 1, dim=-1)
     top_scores_index = top_scores.indices
     top_scores_index = torch.sort(top_scores_index).values
-
     draft_tokens = torch.gather(ss_token_list, index=top_scores_index, dim=1)
     draft_tokens = torch.cat((verified_id.unsqueeze(1), draft_tokens), dim=1).flatten()
 
@@ -108,296 +112,6 @@ def build_tree_kernel_efficient(
     )
 
 
-def build_tree_kernel(
-    verified_id: torch.Tensor,
-    score_list: List[torch.Tensor],
-    token_list: List[torch.Tensor],
-    parents_list: List[torch.Tensor],
-    seq_lens: torch.Tensor,
-    seq_lens_sum: int,
-    topk: int,
-    spec_steps: int,
-    num_verify_tokens: int,
-):
-    parent_list, top_scores_index, draft_tokens = (
-        build_tree_kernel_efficient_preprocess(
-            verified_id,
-            score_list,
-            token_list,
-            parents_list,
-            num_verify_tokens,
-        )
-    )
-
-    bs = seq_lens.numel()
-    device = seq_lens.device
-
-    tree_mask = torch.full(
-        (
-            seq_lens_sum * num_verify_tokens
-            + num_verify_tokens * num_verify_tokens * bs,
-        ),
-        True,
-        device=device,
-    )
-    retrive_index = torch.full(
-        (bs, num_verify_tokens, spec_steps + 2), -1, device=device, dtype=torch.long
-    )
-    positions = torch.empty((bs * num_verify_tokens,), device=device, dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list,
-        top_scores_index,
-        seq_lens.to(torch.int32),
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        spec_steps,
-        num_verify_tokens,
-    )
-
-    index = retrive_index.sum(dim=-1) != -spec_steps - 2
-    cum_len = torch.cumsum(torch.sum(index, dim=-1), dim=-1)
-    retrive_cum_len = torch.zeros(
-        (cum_len.numel() + 1,), dtype=torch.int32, device="cuda"
-    )
-    retrive_cum_len[1:] = cum_len
-    # TODO: this indexing cause a synchronization, optimize this
-    retrive_index = retrive_index[index]
-    return tree_mask, positions, retrive_index, retrive_cum_len, draft_tokens
-
-
-def test_build_tree_kernel():
-    def findp(p_i, index, parent_list):
-        pos = index // 10
-        index_list = index.tolist()
-        parent_list = parent_list.tolist()
-        res = [p_i]
-        while True:
-            p = pos[p_i]
-            if p == 0:
-                break
-            token_idx = parent_list[p]
-            p_i = index_list.index(token_idx)
-            res.append(p_i)
-        return res
-
-    def create_mask(seq_len, draft_token, index, parent_list, max_depth):
-        mask = []
-        positions = []
-        retrive_index = []
-        for i, lens in enumerate(seq_len.tolist()):
-            first_mask = torch.full((lens + draft_token,), True)
-            first_mask[-(draft_token - 1) :] = False
-            positions.append(lens)
-            mask.append(first_mask)
-            seq_order = []
-            first_index = torch.Tensor([0] + [-1] * (depth + 1)).cuda().to(torch.long)
-            r_index = [first_index]
-            for j in range(draft_token - 1):
-                mask.append(torch.full((lens + 1,), True))
-                idx = findp(j, index, parent_list)
-
-                seq_order.append(idx)
-                positions.append(len(idx) + seq_len)
-                t = torch.full((draft_token - 1,), False)
-                t[idx] = True
-                mask.append(t)
-
-            for i in range(1, draft_token - 1):
-                is_leaf = 0
-                for j in range(draft_token - 1):
-                    if i in seq_order[j]:
-                        is_leaf += 1
-
-                if is_leaf == 1:
-                    order_list = [0] + [x + 1 for x in seq_order[i][::-1]]
-                    for _ in range(max_depth + 1 - len(seq_order[i])):
-                        order_list.append(-1)
-                    order = torch.Tensor(order_list).cuda().to(torch.long)
-                    r_index.append(order)
-            retrive_index.append(torch.stack(r_index))
-
-        return (
-            torch.cat(mask).cuda(),
-            torch.Tensor(positions).cuda().to(torch.long),
-            torch.stack(retrive_index),
-        )
-
-    index = (
-        torch.Tensor(
-            [
-                0,
-                1,
-                2,
-                3,
-                10,
-                11,
-                12,
-                13,
-                20,
-                21,
-                22,
-                30,
-                110,
-                130,
-                150,
-                160,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                217,
-                218,
-                219,
-                220,
-                230,
-                310,
-                311,
-                312,
-                313,
-                314,
-                315,
-                316,
-                317,
-                320,
-                321,
-                322,
-                330,
-                360,
-                380,
-                390,
-                410,
-                411,
-                412,
-                413,
-                414,
-                415,
-                416,
-                417,
-                418,
-                419,
-                420,
-                421,
-                422,
-                423,
-                430,
-                431,
-                440,
-                441,
-                460,
-                470,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    parent_list = (
-        torch.Tensor(
-            [
-                -1,
-                0,
-                1,
-                2,
-                3,
-                4,
-                5,
-                6,
-                7,
-                8,
-                9,
-                10,
-                11,
-                12,
-                20,
-                30,
-                21,
-                13,
-                22,
-                40,
-                23,
-                110,
-                130,
-                160,
-                150,
-                190,
-                120,
-                111,
-                121,
-                200,
-                180,
-                210,
-                211,
-                212,
-                213,
-                214,
-                215,
-                216,
-                220,
-                230,
-                217,
-                310,
-                311,
-                312,
-                313,
-                320,
-                314,
-                321,
-                315,
-                316,
-                317,
-            ]
-        )
-        .to(torch.long)
-        .cuda()
-    )
-
-    verified_seq_len = torch.Tensor([47]).to(torch.long).cuda()
-    bs = verified_seq_len.shape[0]
-    topk = 10
-    depth = 5  # depth <= 10
-    num_draft_token = 64
-
-    tree_mask = torch.full(
-        (
-            torch.sum(verified_seq_len).item() * num_draft_token
-            + num_draft_token * num_draft_token * bs,
-        ),
-        True,
-    ).cuda()
-    retrive_index = torch.full(
-        (bs, num_draft_token, depth + 2), -1, device="cuda", dtype=torch.long
-    )
-    positions = torch.empty((bs * num_draft_token,), device="cuda", dtype=torch.long)
-
-    sgl_build_tree_kernel(
-        parent_list.unsqueeze(0),
-        index.unsqueeze(0),
-        verified_seq_len,
-        tree_mask,
-        positions,
-        retrive_index,
-        topk,
-        depth,
-        num_draft_token,
-    )
-
-    retrive_index = retrive_index[retrive_index.sum(dim=-1) != -depth - 2]
-
-    c_mask, c_positions, c_retive_index = create_mask(
-        verified_seq_len, num_draft_token, index, parent_list, depth
-    )
-
-    assert torch.allclose(tree_mask, c_mask), "tree mask has error."
-    assert torch.allclose(positions, c_positions), "positions has error."
-    assert torch.allclose(retrive_index, c_retive_index), "retrive_index has error."
-
-
 def test_build_tree_kernel_efficient():
     verified_id = torch.tensor([29974, 13], device="cuda", dtype=torch.int32)
     score_list = [
@@ -611,59 +325,6 @@ def test_build_tree_kernel_efficient():
     depth = 4
     num_draft_token = 8
 
-    tree_mask, position, retrive_index, retrive_cum_len, draft_tokens = (
-        build_tree_kernel(
-            verified_id=verified_id,
-            score_list=score_list,
-            token_list=token_list,
-            parents_list=parents_list,
-            seq_lens=seq_lens,
-            seq_lens_sum=torch.sum(seq_lens).item(),
-            topk=topk,
-            spec_steps=depth,
-            num_verify_tokens=num_draft_token,
-        )
-    )
-
-    from sglang.srt.utils import first_rank_print
-
-    first_rank_print("=========== build tree kernel ==========")
-    # first_rank_print(f"{tree_mask=}", flush=True)
-    first_rank_print(f"{position=}", flush=True)
-    first_rank_print(f"{retrive_index=}", flush=True)
-    first_rank_print(f"{retrive_cum_len=}", flush=True)
-    first_rank_print(f"{draft_tokens=}", flush=True)
-    assert position.tolist() == [5, 6, 6, 7, 7, 8, 8, 9, 10, 11, 12, 12, 12, 12, 13, 14]
-    assert retrive_index.tolist() == [
-        [0, -1, -1, -1, -1, -1],
-        [0, 2, 4, 6, -1, -1],
-        [0, 1, 3, 5, 7, -1],
-        [8, -1, -1, -1, -1, -1],
-        [8, 9, 10, -1, -1, -1],
-        [8, 9, 12, -1, -1, -1],
-        [8, 9, 13, -1, -1, -1],
-        [8, 9, 11, 14, 15, -1],
-    ]
-    assert retrive_cum_len.tolist() == [0, 3, 8]
-    assert draft_tokens.tolist() == [
-        29974,
-        29896,
-        29906,
-        29889,
-        29974,
-        29946,
-        29896,
-        29946,
-        13,
-        13,
-        22550,
-        4136,
-        16492,
-        8439,
-        29871,
-        29941,
-    ]
-
     (
         tree_mask,
         position,
@@ -725,4 +386,3 @@ def test_build_tree_kernel_efficient():
 
 if __name__ == "__main__":
     test_build_tree_kernel_efficient()
-    test_build_tree_kernel()