sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +20 -0
- sglang/compile_deep_gemm.py +8 -1
- sglang/global_config.py +5 -1
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/conversation.py +0 -112
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
- sglang/srt/disaggregation/launch_lb.py +5 -20
- sglang/srt/disaggregation/mooncake/conn.py +33 -15
- sglang/srt/disaggregation/prefill.py +1 -0
- sglang/srt/distributed/device_communicators/pynccl.py +7 -0
- sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
- sglang/srt/distributed/parallel_state.py +11 -0
- sglang/srt/entrypoints/engine.py +4 -2
- sglang/srt/entrypoints/http_server.py +35 -15
- sglang/srt/eplb/expert_distribution.py +4 -2
- sglang/srt/hf_transformers_utils.py +25 -10
- sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
- sglang/srt/layers/attention/flashattention_backend.py +7 -11
- sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
- sglang/srt/layers/attention/utils.py +6 -1
- sglang/srt/layers/attention/vision.py +27 -10
- sglang/srt/layers/communicator.py +14 -4
- sglang/srt/layers/linear.py +7 -1
- sglang/srt/layers/logits_processor.py +9 -1
- sglang/srt/layers/moe/ep_moe/layer.py +29 -68
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
- sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
- sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
- sglang/srt/layers/moe/utils.py +43 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/fp8.py +57 -1
- sglang/srt/layers/quantization/fp8_kernel.py +0 -4
- sglang/srt/layers/quantization/w8a8_int8.py +4 -1
- sglang/srt/layers/vocab_parallel_embedding.py +7 -1
- sglang/srt/lora/lora_registry.py +7 -0
- sglang/srt/managers/cache_controller.py +43 -39
- sglang/srt/managers/data_parallel_controller.py +52 -2
- sglang/srt/managers/io_struct.py +6 -1
- sglang/srt/managers/schedule_batch.py +3 -2
- sglang/srt/managers/schedule_policy.py +3 -1
- sglang/srt/managers/scheduler.py +145 -6
- sglang/srt/managers/template_manager.py +25 -22
- sglang/srt/managers/tokenizer_manager.py +114 -62
- sglang/srt/managers/utils.py +45 -1
- sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
- sglang/srt/mem_cache/hicache_storage.py +13 -12
- sglang/srt/mem_cache/hiradix_cache.py +21 -4
- sglang/srt/mem_cache/memory_pool.py +15 -118
- sglang/srt/mem_cache/memory_pool_host.py +350 -33
- sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
- sglang/srt/model_executor/cuda_graph_runner.py +42 -4
- sglang/srt/model_executor/forward_batch_info.py +13 -3
- sglang/srt/model_executor/model_runner.py +13 -1
- sglang/srt/model_loader/weight_utils.py +2 -0
- sglang/srt/models/deepseek_v2.py +28 -23
- sglang/srt/models/glm4_moe.py +85 -22
- sglang/srt/models/grok.py +3 -3
- sglang/srt/models/llama4.py +13 -2
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mllama4.py +428 -19
- sglang/srt/models/qwen2_moe.py +1 -4
- sglang/srt/models/qwen3_moe.py +7 -8
- sglang/srt/models/step3_vl.py +1 -4
- sglang/srt/multimodal/processors/base_processor.py +4 -3
- sglang/srt/multimodal/processors/gemma3n.py +0 -7
- sglang/srt/operations_strategy.py +1 -1
- sglang/srt/server_args.py +115 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
- sglang/srt/two_batch_overlap.py +6 -4
- sglang/srt/utils.py +4 -24
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +945 -0
- sglang/test/runners.py +2 -2
- sglang/test/test_utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
- /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
@@ -418,6 +418,26 @@ if __name__ == "__main__":
|
|
418
418
|
ServerArgs.add_cli_args(parser)
|
419
419
|
BenchArgs.add_cli_args(parser)
|
420
420
|
args = parser.parse_args()
|
421
|
+
|
422
|
+
# handling ModelScope model downloads
|
423
|
+
if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
|
424
|
+
if os.path.exists(args.model_path):
|
425
|
+
print(f"Using local model path: {args.model_path}")
|
426
|
+
else:
|
427
|
+
try:
|
428
|
+
from modelscope import snapshot_download
|
429
|
+
|
430
|
+
print(f"Using ModelScope to download model: {args.model_path}")
|
431
|
+
|
432
|
+
# download the model and replace args.model_path
|
433
|
+
args.model_path = snapshot_download(
|
434
|
+
args.model_path,
|
435
|
+
)
|
436
|
+
print(f"Model downloaded to: {args.model_path}")
|
437
|
+
except Exception as e:
|
438
|
+
print(f"ModelScope download failed: {str(e)}")
|
439
|
+
raise e
|
440
|
+
|
421
441
|
server_args = ServerArgs.from_cli_args(args)
|
422
442
|
bench_args = BenchArgs.from_cli_args(args)
|
423
443
|
|
sglang/compile_deep_gemm.py
CHANGED
@@ -17,6 +17,7 @@ import time
|
|
17
17
|
|
18
18
|
import requests
|
19
19
|
|
20
|
+
from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
|
20
21
|
from sglang.srt.entrypoints.http_server import launch_server
|
21
22
|
from sglang.srt.managers.io_struct import GenerateReqInput
|
22
23
|
from sglang.srt.managers.tokenizer_manager import TokenizerManager
|
@@ -52,7 +53,9 @@ class CompileArgs:
|
|
52
53
|
|
53
54
|
|
54
55
|
@warmup("compile-deep-gemm")
|
55
|
-
async def warm_up_compile(
|
56
|
+
async def warm_up_compile(
|
57
|
+
disaggregation_mode: str, tokenizer_manager: TokenizerManager
|
58
|
+
):
|
56
59
|
print("\nGenerate warm up request for compiling DeepGEMM...\n")
|
57
60
|
generate_req_input = GenerateReqInput(
|
58
61
|
input_ids=[0, 1, 2, 3],
|
@@ -62,6 +65,10 @@ async def warm_up_compile(tokenizer_manager: TokenizerManager):
|
|
62
65
|
"ignore_eos": True,
|
63
66
|
},
|
64
67
|
)
|
68
|
+
if disaggregation_mode != "null":
|
69
|
+
generate_req_input.bootstrap_room = 0
|
70
|
+
generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
|
71
|
+
|
65
72
|
await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
|
66
73
|
|
67
74
|
|
sglang/global_config.py
CHANGED
@@ -30,7 +30,11 @@ class GlobalConfig:
|
|
30
30
|
self.default_new_token_ratio_decay_steps = float(
|
31
31
|
os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
|
32
32
|
)
|
33
|
-
|
33
|
+
self.torch_empty_cache_interval = float(
|
34
|
+
os.environ.get(
|
35
|
+
"SGLANG_EMPTY_CACHE_INTERVAL", -1
|
36
|
+
) # in seconds. Set if you observe high memory accumulation over a long serving period.
|
37
|
+
)
|
34
38
|
# Runtime constants: others
|
35
39
|
self.retract_decode_steps = 20
|
36
40
|
self.flashinfer_workspace_size = os.environ.get(
|
@@ -112,6 +112,7 @@ class ModelConfig:
|
|
112
112
|
mm_disabled_models = [
|
113
113
|
"Gemma3ForConditionalGeneration",
|
114
114
|
"Llama4ForConditionalGeneration",
|
115
|
+
"Step3VLForConditionalGeneration",
|
115
116
|
]
|
116
117
|
if self.hf_config.architectures[0] in mm_disabled_models:
|
117
118
|
enable_multimodal = False
|
sglang/srt/conversation.py
CHANGED
@@ -954,20 +954,6 @@ register_conv_template(
|
|
954
954
|
)
|
955
955
|
)
|
956
956
|
|
957
|
-
register_conv_template(
|
958
|
-
Conversation(
|
959
|
-
name="mimo-vl",
|
960
|
-
system_message="You are MiMo, an AI assistant developed by Xiaomi.",
|
961
|
-
system_template="<|im_start|>system\n{system_message}",
|
962
|
-
roles=("<|im_start|>user", "<|im_start|>assistant"),
|
963
|
-
sep="<|im_end|>\n",
|
964
|
-
sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
|
965
|
-
stop_str=["<|im_end|>"],
|
966
|
-
image_token="<|vision_start|><|image_pad|><|vision_end|>",
|
967
|
-
)
|
968
|
-
)
|
969
|
-
|
970
|
-
|
971
957
|
register_conv_template(
|
972
958
|
Conversation(
|
973
959
|
name="qwen2-audio",
|
@@ -981,51 +967,11 @@ register_conv_template(
|
|
981
967
|
)
|
982
968
|
)
|
983
969
|
|
984
|
-
register_conv_template(
|
985
|
-
Conversation(
|
986
|
-
name="llama_4_vision",
|
987
|
-
system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
|
988
|
-
system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
|
989
|
-
roles=("user", "assistant"),
|
990
|
-
sep_style=SeparatorStyle.LLAMA4,
|
991
|
-
sep="",
|
992
|
-
stop_str="<|eot|>",
|
993
|
-
image_token="<|image|>",
|
994
|
-
)
|
995
|
-
)
|
996
|
-
|
997
|
-
register_conv_template(
|
998
|
-
Conversation(
|
999
|
-
name="step3-vl",
|
1000
|
-
system_message="<|begin▁of▁sentence|>You are a helpful assistant",
|
1001
|
-
system_template="{system_message}\n",
|
1002
|
-
roles=(
|
1003
|
-
"<|BOT|>user\n",
|
1004
|
-
"<|BOT|>assistant\n<think>\n",
|
1005
|
-
),
|
1006
|
-
sep="<|EOT|>",
|
1007
|
-
sep_style=SeparatorStyle.NO_COLON_SINGLE,
|
1008
|
-
stop_str="<|EOT|>",
|
1009
|
-
image_token="<im_patch>",
|
1010
|
-
# add_bos=True,
|
1011
|
-
)
|
1012
|
-
)
|
1013
|
-
|
1014
970
|
|
1015
971
|
@register_conv_template_matching_function
|
1016
972
|
def match_internvl(model_path: str):
|
1017
973
|
if re.search(r"internvl", model_path, re.IGNORECASE):
|
1018
974
|
return "internvl-2-5"
|
1019
|
-
if re.search(r"intern.*s1", model_path, re.IGNORECASE):
|
1020
|
-
return "interns1"
|
1021
|
-
|
1022
|
-
|
1023
|
-
@register_conv_template_matching_function
|
1024
|
-
def match_llama_vision(model_path: str):
|
1025
|
-
if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
|
1026
|
-
return "llama_3_vision"
|
1027
|
-
if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
|
1028
|
-
return "llama_4_vision"
|
1029
975
|
|
1030
976
|
|
1031
977
|
@register_conv_template_matching_function
|
@@ -1040,22 +986,6 @@ def match_vicuna(model_path: str):
|
|
1040
986
|
return "vicuna_v1.1"
|
1041
987
|
|
1042
988
|
|
1043
|
-
@register_conv_template_matching_function
|
1044
|
-
def match_llama2_chat(model_path: str):
|
1045
|
-
if re.search(
|
1046
|
-
r"llama-2.*chat|codellama.*instruct",
|
1047
|
-
model_path,
|
1048
|
-
re.IGNORECASE,
|
1049
|
-
):
|
1050
|
-
return "llama-2"
|
1051
|
-
|
1052
|
-
|
1053
|
-
@register_conv_template_matching_function
|
1054
|
-
def match_mistral(model_path: str):
|
1055
|
-
if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
|
1056
|
-
return "mistral"
|
1057
|
-
|
1058
|
-
|
1059
989
|
@register_conv_template_matching_function
|
1060
990
|
def match_deepseek_vl(model_path: str):
|
1061
991
|
if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
|
@@ -1064,12 +994,6 @@ def match_deepseek_vl(model_path: str):
|
|
1064
994
|
|
1065
995
|
@register_conv_template_matching_function
|
1066
996
|
def match_qwen_chat_ml(model_path: str):
|
1067
|
-
if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
|
1068
|
-
return "gme-qwen2-vl"
|
1069
|
-
if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
|
1070
|
-
return "qwen2-vl"
|
1071
|
-
if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
|
1072
|
-
return "qwen2-audio"
|
1073
997
|
if re.search(
|
1074
998
|
r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
|
1075
999
|
model_path,
|
@@ -1078,12 +1002,6 @@ def match_qwen_chat_ml(model_path: str):
|
|
1078
1002
|
return "chatml-llava"
|
1079
1003
|
|
1080
1004
|
|
1081
|
-
@register_conv_template_matching_function
|
1082
|
-
def match_gemma3_instruct(model_path: str):
|
1083
|
-
if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
|
1084
|
-
return "gemma-it"
|
1085
|
-
|
1086
|
-
|
1087
1005
|
@register_conv_template_matching_function
|
1088
1006
|
def match_openbmb_minicpm(model_path: str):
|
1089
1007
|
if re.search(r"minicpm-v", model_path, re.IGNORECASE):
|
@@ -1092,37 +1010,7 @@ def match_openbmb_minicpm(model_path: str):
|
|
1092
1010
|
return "minicpmo"
|
1093
1011
|
|
1094
1012
|
|
1095
|
-
@register_conv_template_matching_function
|
1096
|
-
def match_moonshot_kimivl(model_path: str):
|
1097
|
-
if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
|
1098
|
-
return "kimi-vl"
|
1099
|
-
|
1100
|
-
|
1101
|
-
@register_conv_template_matching_function
|
1102
|
-
def match_devstral(model_path: str):
|
1103
|
-
if re.search(r"devstral", model_path, re.IGNORECASE):
|
1104
|
-
return "devstral"
|
1105
|
-
|
1106
|
-
|
1107
1013
|
@register_conv_template_matching_function
|
1108
1014
|
def match_phi_4_mm(model_path: str):
|
1109
1015
|
if "phi-4-multimodal" in model_path.lower():
|
1110
1016
|
return "phi-4-mm"
|
1111
|
-
|
1112
|
-
|
1113
|
-
@register_conv_template_matching_function
|
1114
|
-
def match_vila(model_path: str):
|
1115
|
-
if re.search(r"vila", model_path, re.IGNORECASE):
|
1116
|
-
return "chatml"
|
1117
|
-
|
1118
|
-
|
1119
|
-
@register_conv_template_matching_function
|
1120
|
-
def match_mimo_vl(model_path: str):
|
1121
|
-
if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
|
1122
|
-
return "mimo-vl"
|
1123
|
-
|
1124
|
-
|
1125
|
-
# @register_conv_template_matching_function
|
1126
|
-
# def match_step3(model_path: str):
|
1127
|
-
# if re.search(r"step3", model_path, re.IGNORECASE):
|
1128
|
-
# return "step3-vl"
|
@@ -88,6 +88,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
88
88
|
self.extend_lens = [r.extend_input_len for r in reqs]
|
89
89
|
self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
|
90
90
|
self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
|
91
|
+
self.multimodal_inputs = [r.multimodal_inputs for r in reqs]
|
91
92
|
|
92
93
|
# Build sampling info
|
93
94
|
self.sampling_info = SamplingBatchInfo.from_schedule_batch(
|
@@ -1,6 +1,8 @@
|
|
1
1
|
import argparse
|
2
2
|
import dataclasses
|
3
3
|
|
4
|
+
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
5
|
+
|
4
6
|
|
5
7
|
@dataclasses.dataclass
|
6
8
|
class LBArgs:
|
@@ -18,7 +20,7 @@ class LBArgs:
|
|
18
20
|
parser.add_argument(
|
19
21
|
"--rust-lb",
|
20
22
|
action="store_true",
|
21
|
-
help="
|
23
|
+
help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
|
22
24
|
)
|
23
25
|
parser.add_argument(
|
24
26
|
"--host",
|
@@ -115,25 +117,8 @@ def main():
|
|
115
117
|
args = parser.parse_args()
|
116
118
|
lb_args = LBArgs.from_cli_args(args)
|
117
119
|
|
118
|
-
|
119
|
-
|
120
|
-
|
121
|
-
RustLB(
|
122
|
-
host=lb_args.host,
|
123
|
-
port=lb_args.port,
|
124
|
-
policy=lb_args.policy,
|
125
|
-
prefill_infos=lb_args.prefill_infos,
|
126
|
-
decode_infos=lb_args.decode_infos,
|
127
|
-
log_interval=lb_args.log_interval,
|
128
|
-
timeout=lb_args.timeout,
|
129
|
-
).start()
|
130
|
-
else:
|
131
|
-
from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
|
132
|
-
|
133
|
-
prefill_configs = [
|
134
|
-
PrefillConfig(url, port) for url, port in lb_args.prefill_infos
|
135
|
-
]
|
136
|
-
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
|
120
|
+
prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
|
121
|
+
run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
|
137
122
|
|
138
123
|
|
139
124
|
if __name__ == "__main__":
|
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
|
|
37
37
|
from sglang.srt.server_args import ServerArgs
|
38
38
|
from sglang.srt.utils import (
|
39
39
|
format_tcp_address,
|
40
|
+
get_bool_env_var,
|
40
41
|
get_free_port,
|
41
42
|
get_int_env_var,
|
42
43
|
get_ip,
|
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
198
199
|
self.bootstrap_timeout = get_int_env_var(
|
199
200
|
"SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
|
200
201
|
)
|
202
|
+
|
203
|
+
self.enable_custom_mem_pool = get_bool_env_var(
|
204
|
+
"SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
|
205
|
+
)
|
201
206
|
elif self.disaggregation_mode == DisaggregationMode.DECODE:
|
202
207
|
self.heartbeat_failures = {}
|
203
208
|
self.session_pool = defaultdict(requests.Session)
|
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
|
|
258
263
|
socket.connect(endpoint)
|
259
264
|
return socket
|
260
265
|
|
266
|
+
def _transfer_data(self, mooncake_session_id, transfer_blocks):
|
267
|
+
if not transfer_blocks:
|
268
|
+
return 0
|
269
|
+
|
270
|
+
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
271
|
+
if self.enable_custom_mem_pool:
|
272
|
+
# batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
|
273
|
+
for src_addr, dst_addr, length in transfer_blocks:
|
274
|
+
status = self.engine.transfer_sync(
|
275
|
+
mooncake_session_id, src_addr, dst_addr, length
|
276
|
+
)
|
277
|
+
if status != 0:
|
278
|
+
return status
|
279
|
+
return 0
|
280
|
+
else:
|
281
|
+
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
282
|
+
return self.engine.batch_transfer_sync(
|
283
|
+
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
284
|
+
)
|
285
|
+
|
261
286
|
def send_kvcache(
|
262
287
|
self,
|
263
288
|
mooncake_session_id: str,
|
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
283
308
|
|
284
309
|
# Worker function for processing a single layer
|
285
310
|
def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
|
311
|
+
transfer_blocks = []
|
286
312
|
for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
|
287
313
|
src_addr = src_ptr + int(prefill_index[0]) * item_len
|
288
314
|
dst_addr = dst_ptr + int(decode_index[0]) * item_len
|
289
315
|
length = item_len * len(prefill_index)
|
316
|
+
transfer_blocks.append((src_addr, dst_addr, length))
|
290
317
|
|
291
|
-
|
292
|
-
mooncake_session_id, src_addr, dst_addr, length
|
293
|
-
)
|
294
|
-
if status != 0:
|
295
|
-
return status
|
296
|
-
return 0
|
318
|
+
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
297
319
|
|
298
320
|
futures = [
|
299
321
|
executor.submit(
|
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
|
|
465
487
|
dst_aux_ptrs: list[int],
|
466
488
|
dst_aux_index: int,
|
467
489
|
):
|
468
|
-
|
469
|
-
dst_addr_list = []
|
470
|
-
length_list = []
|
490
|
+
transfer_blocks = []
|
471
491
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
472
492
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
493
|
+
|
473
494
|
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
474
495
|
length = prefill_aux_item_lens[i]
|
475
496
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
476
497
|
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
477
|
-
|
478
|
-
|
479
|
-
|
480
|
-
return self.engine.batch_transfer_sync(
|
481
|
-
mooncake_session_id, src_addr_list, dst_addr_list, length_list
|
482
|
-
)
|
498
|
+
transfer_blocks.append((src_addr, dst_addr, length))
|
499
|
+
|
500
|
+
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
483
501
|
|
484
502
|
def sync_status_to_decode_endpoint(
|
485
503
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
|
|
460
460
|
|
461
461
|
# We need to remove the sync in the following function for overlap schedule.
|
462
462
|
self.set_next_batch_sampling_info_done(batch)
|
463
|
+
self.maybe_send_health_check_signal()
|
463
464
|
|
464
465
|
def process_disagg_prefill_inflight_queue(
|
465
466
|
self: Scheduler, rids_to_check: Optional[List[str]] = None
|
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
|
|
75
75
|
self.available = True
|
76
76
|
self.disabled = False
|
77
77
|
|
78
|
+
self.nccl_version = self.nccl.ncclGetRawVersion()
|
78
79
|
if self.rank == 0:
|
79
80
|
logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
|
80
81
|
|
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
|
|
259
260
|
cudaStream_t(stream.cuda_stream),
|
260
261
|
)
|
261
262
|
|
263
|
+
def register_comm_window_raw(self, ptr: int, size: int):
|
264
|
+
return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
|
265
|
+
|
266
|
+
def deregister_comm_window(self, window):
|
267
|
+
return self.nccl.ncclCommWindowDeregister(self.comm, window)
|
268
|
+
|
262
269
|
@contextmanager
|
263
270
|
def change_state(
|
264
271
|
self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
|
@@ -0,0 +1,133 @@
|
|
1
|
+
import tempfile
|
2
|
+
|
3
|
+
import torch
|
4
|
+
from packaging import version
|
5
|
+
from torch.cuda.memory import CUDAPluggableAllocator
|
6
|
+
|
7
|
+
from sglang.srt.distributed.parallel_state import GroupCoordinator
|
8
|
+
from sglang.srt.managers.schedule_batch import global_server_args_dict
|
9
|
+
|
10
|
+
nccl_allocator_source = """
|
11
|
+
#include <nccl.h>
|
12
|
+
extern "C" {
|
13
|
+
|
14
|
+
void* nccl_alloc_plug(size_t size, int device, void* stream) {
|
15
|
+
void* ptr;
|
16
|
+
ncclResult_t err = ncclMemAlloc(&ptr, size);
|
17
|
+
return ptr;
|
18
|
+
|
19
|
+
}
|
20
|
+
|
21
|
+
void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
|
22
|
+
ncclResult_t err = ncclMemFree(ptr);
|
23
|
+
}
|
24
|
+
|
25
|
+
}
|
26
|
+
"""
|
27
|
+
|
28
|
+
_allocator = None
|
29
|
+
_mem_pool = None
|
30
|
+
_registered_base_addrs = set()
|
31
|
+
_graph_pool_id = None
|
32
|
+
|
33
|
+
|
34
|
+
def is_symmetric_memory_enabled():
|
35
|
+
return global_server_args_dict["enable_symm_mem"]
|
36
|
+
|
37
|
+
|
38
|
+
def set_graph_pool_id(graph_pool_id):
|
39
|
+
global _graph_pool_id
|
40
|
+
_graph_pool_id = graph_pool_id
|
41
|
+
|
42
|
+
|
43
|
+
def get_nccl_mem_pool():
|
44
|
+
global _allocator, _mem_pool
|
45
|
+
if _mem_pool is None:
|
46
|
+
out_dir = tempfile.gettempdir()
|
47
|
+
nccl_allocator_libname = "nccl_allocator"
|
48
|
+
torch.utils.cpp_extension.load_inline(
|
49
|
+
name=nccl_allocator_libname,
|
50
|
+
cpp_sources=nccl_allocator_source,
|
51
|
+
with_cuda=True,
|
52
|
+
extra_ldflags=["-lnccl"],
|
53
|
+
verbose=True,
|
54
|
+
is_python_module=False,
|
55
|
+
build_directory=out_dir,
|
56
|
+
)
|
57
|
+
_allocator = CUDAPluggableAllocator(
|
58
|
+
f"{out_dir}/{nccl_allocator_libname}.so",
|
59
|
+
"nccl_alloc_plug",
|
60
|
+
"nccl_free_plug",
|
61
|
+
).allocator()
|
62
|
+
_mem_pool = torch.cuda.MemPool(_allocator)
|
63
|
+
return _mem_pool
|
64
|
+
|
65
|
+
|
66
|
+
class use_symmetric_memory:
|
67
|
+
def __init__(self, group_coordinator: GroupCoordinator):
|
68
|
+
if not is_symmetric_memory_enabled():
|
69
|
+
self.group_coordinator = None
|
70
|
+
self._mem_pool_ctx = None
|
71
|
+
self.is_graph_capture = None
|
72
|
+
self.device = None
|
73
|
+
self.pre_2_8_0 = None
|
74
|
+
else:
|
75
|
+
self.group_coordinator = group_coordinator
|
76
|
+
self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
|
77
|
+
self.is_graph_capture = torch.cuda.is_current_stream_capturing()
|
78
|
+
self.device = torch.cuda.current_device()
|
79
|
+
self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
|
80
|
+
|
81
|
+
def __enter__(self):
|
82
|
+
if not is_symmetric_memory_enabled():
|
83
|
+
return self
|
84
|
+
assert (
|
85
|
+
self.group_coordinator.pynccl_comm is not None
|
86
|
+
), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
|
87
|
+
assert (
|
88
|
+
self.group_coordinator.pynccl_comm.nccl_version >= 22703
|
89
|
+
), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
|
90
|
+
if self.is_graph_capture:
|
91
|
+
assert (
|
92
|
+
_graph_pool_id is not None
|
93
|
+
), "graph_pool_id is not set under graph capture"
|
94
|
+
# Pause graph memory pool to use symmetric memory with cuda graph
|
95
|
+
if self.pre_2_8_0:
|
96
|
+
torch._C._cuda_endAllocateCurrentStreamToPool(
|
97
|
+
self.device, _graph_pool_id
|
98
|
+
)
|
99
|
+
else:
|
100
|
+
torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
|
101
|
+
self._mem_pool_ctx.__enter__()
|
102
|
+
return self
|
103
|
+
|
104
|
+
def tag(self, tensor: torch.Tensor):
|
105
|
+
if not is_symmetric_memory_enabled():
|
106
|
+
return
|
107
|
+
tensor.symmetric_memory = True
|
108
|
+
|
109
|
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
110
|
+
if not is_symmetric_memory_enabled():
|
111
|
+
return
|
112
|
+
global _registered_base_addrs
|
113
|
+
self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
|
114
|
+
for segment in get_nccl_mem_pool().snapshot():
|
115
|
+
if segment["address"] not in _registered_base_addrs:
|
116
|
+
if segment["stream"] == 0 and self.pre_2_8_0:
|
117
|
+
# PyTorch version < 2.8.0 has a multi-thread MemPool bug
|
118
|
+
# See https://github.com/pytorch/pytorch/issues/152861
|
119
|
+
# Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
|
120
|
+
# WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
|
121
|
+
continue
|
122
|
+
self.group_coordinator.pynccl_comm.register_comm_window_raw(
|
123
|
+
segment["address"], segment["total_size"]
|
124
|
+
)
|
125
|
+
_registered_base_addrs.add(segment["address"])
|
126
|
+
|
127
|
+
if self.is_graph_capture:
|
128
|
+
if self.pre_2_8_0:
|
129
|
+
torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
|
130
|
+
else:
|
131
|
+
torch._C._cuda_beginAllocateCurrentThreadToPool(
|
132
|
+
self.device, _graph_pool_id
|
133
|
+
)
|
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
|
|
67
67
|
|
68
68
|
ncclResult_t = ctypes.c_int
|
69
69
|
ncclComm_t = ctypes.c_void_p
|
70
|
+
ncclWindow_t = ctypes.c_void_p
|
70
71
|
|
71
72
|
|
72
73
|
class ncclUniqueId(ctypes.Structure):
|
@@ -279,6 +280,23 @@ class NCCLLibrary:
|
|
279
280
|
Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
|
280
281
|
]
|
281
282
|
|
283
|
+
exported_functions_symm_mem = [
|
284
|
+
# ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
|
285
|
+
Function(
|
286
|
+
"ncclCommWindowRegister",
|
287
|
+
ncclResult_t,
|
288
|
+
[
|
289
|
+
ncclComm_t,
|
290
|
+
buffer_type,
|
291
|
+
ctypes.c_size_t,
|
292
|
+
ctypes.POINTER(ncclWindow_t),
|
293
|
+
ctypes.c_int,
|
294
|
+
],
|
295
|
+
),
|
296
|
+
# ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
|
297
|
+
Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
|
298
|
+
]
|
299
|
+
|
282
300
|
# class attribute to store the mapping from the path to the library
|
283
301
|
# to avoid loading the same library multiple times
|
284
302
|
path_to_library_cache: Dict[str, Any] = {}
|
@@ -312,7 +330,10 @@ class NCCLLibrary:
|
|
312
330
|
|
313
331
|
if so_file not in NCCLLibrary.path_to_dict_mapping:
|
314
332
|
_funcs: Dict[str, Any] = {}
|
315
|
-
|
333
|
+
exported_functions = NCCLLibrary.exported_functions
|
334
|
+
if hasattr(self.lib, "ncclCommWindowRegister"):
|
335
|
+
exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
|
336
|
+
for func in exported_functions:
|
316
337
|
f = getattr(self.lib, func.name)
|
317
338
|
f.restype = func.restype
|
318
339
|
f.argtypes = func.argtypes
|
@@ -328,10 +349,14 @@ class NCCLLibrary:
|
|
328
349
|
error_str = self.ncclGetErrorString(result)
|
329
350
|
raise RuntimeError(f"NCCL error: {error_str}")
|
330
351
|
|
331
|
-
def
|
352
|
+
def ncclGetRawVersion(self) -> int:
|
332
353
|
version = ctypes.c_int()
|
333
354
|
self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
|
334
|
-
|
355
|
+
# something like 21903
|
356
|
+
return version.value
|
357
|
+
|
358
|
+
def ncclGetVersion(self) -> str:
|
359
|
+
version_str = str(self.ncclGetRawVersion())
|
335
360
|
# something like 21903 --> "2.19.3"
|
336
361
|
major = version_str[0].lstrip("0")
|
337
362
|
minor = version_str[1:3].lstrip("0")
|
@@ -460,6 +485,20 @@ class NCCLLibrary:
|
|
460
485
|
def ncclCommDestroy(self, comm: ncclComm_t) -> None:
|
461
486
|
self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
|
462
487
|
|
488
|
+
def ncclCommWindowRegister(
|
489
|
+
self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
|
490
|
+
) -> ncclWindow_t:
|
491
|
+
window = ncclWindow_t()
|
492
|
+
self.NCCL_CHECK(
|
493
|
+
self._funcs["ncclCommWindowRegister"](
|
494
|
+
comm, buff, size, ctypes.byref(window), win_flags
|
495
|
+
)
|
496
|
+
)
|
497
|
+
return window
|
498
|
+
|
499
|
+
def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
|
500
|
+
self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
|
501
|
+
|
463
502
|
|
464
503
|
__all__ = [
|
465
504
|
"NCCLLibrary",
|
@@ -497,6 +497,17 @@ class GroupCoordinator:
|
|
497
497
|
if self.npu_communicator is not None and not self.npu_communicator.disabled:
|
498
498
|
return self.npu_communicator.all_reduce(input_)
|
499
499
|
|
500
|
+
if (
|
501
|
+
self.pynccl_comm is not None
|
502
|
+
and hasattr(input_, "symmetric_memory")
|
503
|
+
and input_.symmetric_memory
|
504
|
+
):
|
505
|
+
with self.pynccl_comm.change_state(
|
506
|
+
enable=True, stream=torch.cuda.current_stream()
|
507
|
+
):
|
508
|
+
self.pynccl_comm.all_reduce(input_)
|
509
|
+
return input_
|
510
|
+
|
500
511
|
outplace_all_reduce_method = None
|
501
512
|
if (
|
502
513
|
self.qr_comm is not None
|
sglang/srt/entrypoints/engine.py
CHANGED
@@ -623,8 +623,9 @@ class Engine(EngineBase):
|
|
623
623
|
def _set_envs_and_config(server_args: ServerArgs):
|
624
624
|
# Set global environments
|
625
625
|
os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
|
626
|
-
os.environ["NCCL_CUMEM_ENABLE"] =
|
627
|
-
|
626
|
+
os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
|
627
|
+
if not server_args.enable_symm_mem:
|
628
|
+
os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
|
628
629
|
os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
|
629
630
|
os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
|
630
631
|
os.environ["CUDA_MODULE_LOADING"] = "AUTO"
|
@@ -731,6 +732,7 @@ def _launch_subprocesses(
|
|
731
732
|
pp_rank,
|
732
733
|
None,
|
733
734
|
writer,
|
735
|
+
None,
|
734
736
|
),
|
735
737
|
)
|
736
738
|
|