sglang 0.4.10__py3-none-any.whl → 0.4.10.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (92)
  1. sglang/bench_offline_throughput.py +20 -0
  2. sglang/compile_deep_gemm.py +8 -1
  3. sglang/global_config.py +5 -1
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/conversation.py +0 -112
  6. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +1 -0
  7. sglang/srt/disaggregation/launch_lb.py +5 -20
  8. sglang/srt/disaggregation/mooncake/conn.py +33 -15
  9. sglang/srt/disaggregation/prefill.py +1 -0
  10. sglang/srt/distributed/device_communicators/pynccl.py +7 -0
  11. sglang/srt/distributed/device_communicators/pynccl_allocator.py +133 -0
  12. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +42 -3
  13. sglang/srt/distributed/parallel_state.py +11 -0
  14. sglang/srt/entrypoints/engine.py +4 -2
  15. sglang/srt/entrypoints/http_server.py +35 -15
  16. sglang/srt/eplb/expert_distribution.py +4 -2
  17. sglang/srt/hf_transformers_utils.py +25 -10
  18. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  19. sglang/srt/layers/attention/flashattention_backend.py +7 -11
  20. sglang/srt/layers/attention/trtllm_mla_backend.py +372 -0
  21. sglang/srt/layers/attention/utils.py +6 -1
  22. sglang/srt/layers/attention/vision.py +27 -10
  23. sglang/srt/layers/communicator.py +14 -4
  24. sglang/srt/layers/linear.py +7 -1
  25. sglang/srt/layers/logits_processor.py +9 -1
  26. sglang/srt/layers/moe/ep_moe/layer.py +29 -68
  27. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=352,device_name=NVIDIA_RTX_6000_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/layer.py +82 -25
  29. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +0 -31
  30. sglang/srt/layers/moe/token_dispatcher/__init__.py +23 -0
  31. sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +12 -1
  32. sglang/srt/layers/moe/{ep_moe/token_dispatcher.py → token_dispatcher/deepep.py} +8 -15
  33. sglang/srt/layers/moe/utils.py +43 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +3 -2
  35. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
  36. sglang/srt/layers/quantization/fp8.py +57 -1
  37. sglang/srt/layers/quantization/fp8_kernel.py +0 -4
  38. sglang/srt/layers/quantization/w8a8_int8.py +4 -1
  39. sglang/srt/layers/vocab_parallel_embedding.py +7 -1
  40. sglang/srt/lora/lora_registry.py +7 -0
  41. sglang/srt/managers/cache_controller.py +43 -39
  42. sglang/srt/managers/data_parallel_controller.py +52 -2
  43. sglang/srt/managers/io_struct.py +6 -1
  44. sglang/srt/managers/schedule_batch.py +3 -2
  45. sglang/srt/managers/schedule_policy.py +3 -1
  46. sglang/srt/managers/scheduler.py +145 -6
  47. sglang/srt/managers/template_manager.py +25 -22
  48. sglang/srt/managers/tokenizer_manager.py +114 -62
  49. sglang/srt/managers/utils.py +45 -1
  50. sglang/srt/mem_cache/cpp_radix_tree/radix_tree.py +182 -0
  51. sglang/srt/mem_cache/hicache_storage.py +13 -12
  52. sglang/srt/mem_cache/hiradix_cache.py +21 -4
  53. sglang/srt/mem_cache/memory_pool.py +15 -118
  54. sglang/srt/mem_cache/memory_pool_host.py +350 -33
  55. sglang/srt/mem_cache/radix_cache_cpp.py +229 -0
  56. sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +8 -2
  57. sglang/srt/mem_cache/storage/hf3fs/hf3fs_utils.cpp +35 -0
  58. sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +163 -0
  59. sglang/srt/mem_cache/storage/nixl/nixl_utils.py +238 -0
  60. sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +216 -0
  61. sglang/srt/model_executor/cuda_graph_runner.py +42 -4
  62. sglang/srt/model_executor/forward_batch_info.py +13 -3
  63. sglang/srt/model_executor/model_runner.py +13 -1
  64. sglang/srt/model_loader/weight_utils.py +2 -0
  65. sglang/srt/models/deepseek_v2.py +28 -23
  66. sglang/srt/models/glm4_moe.py +85 -22
  67. sglang/srt/models/grok.py +3 -3
  68. sglang/srt/models/llama4.py +13 -2
  69. sglang/srt/models/mixtral.py +3 -3
  70. sglang/srt/models/mllama4.py +428 -19
  71. sglang/srt/models/qwen2_moe.py +1 -4
  72. sglang/srt/models/qwen3_moe.py +7 -8
  73. sglang/srt/models/step3_vl.py +1 -4
  74. sglang/srt/multimodal/processors/base_processor.py +4 -3
  75. sglang/srt/multimodal/processors/gemma3n.py +0 -7
  76. sglang/srt/operations_strategy.py +1 -1
  77. sglang/srt/server_args.py +115 -21
  78. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +18 -0
  79. sglang/srt/two_batch_overlap.py +6 -4
  80. sglang/srt/utils.py +4 -24
  81. sglang/srt/weight_sync/utils.py +1 -1
  82. sglang/test/attention/test_trtllm_mla_backend.py +945 -0
  83. sglang/test/runners.py +2 -2
  84. sglang/test/test_utils.py +3 -3
  85. sglang/version.py +1 -1
  86. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/METADATA +3 -2
  87. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/RECORD +92 -81
  88. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/mooncake_store.py +0 -0
  89. /sglang/srt/mem_cache/{mooncake_store → storage/mooncake_store}/unit_test.py +0 -0
  90. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/WHEEL +0 -0
  91. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/licenses/LICENSE +0 -0
  92. {sglang-0.4.10.dist-info → sglang-0.4.10.post2.dist-info}/top_level.txt +0 -0
sglang/bench_offline_throughput.py CHANGED
@@ -418,6 +418,26 @@ if __name__ == "__main__":
     ServerArgs.add_cli_args(parser)
     BenchArgs.add_cli_args(parser)
     args = parser.parse_args()
+
+    # handling ModelScope model downloads
+    if os.getenv("SGLANG_USE_MODELSCOPE", "false").lower() in ("true", "1"):
+        if os.path.exists(args.model_path):
+            print(f"Using local model path: {args.model_path}")
+        else:
+            try:
+                from modelscope import snapshot_download
+
+                print(f"Using ModelScope to download model: {args.model_path}")
+
+                # download the model and replace args.model_path
+                args.model_path = snapshot_download(
+                    args.model_path,
+                )
+                print(f"Model downloaded to: {args.model_path}")
+            except Exception as e:
+                print(f"ModelScope download failed: {str(e)}")
+                raise e
+
     server_args = ServerArgs.from_cli_args(args)
     bench_args = BenchArgs.from_cli_args(args)
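Note: the ModelScope path above is opt-in via the SGLANG_USE_MODELSCOPE environment variable and only downloads when args.model_path does not already exist locally. A minimal sketch of exercising the same mechanism (the model id here is illustrative, not taken from this diff; requires the modelscope package):

    import os

    # Opt in before launching the benchmark script.
    os.environ["SGLANG_USE_MODELSCOPE"] = "true"

    from modelscope import snapshot_download

    # Resolves a ModelScope model id to a local cache directory; the benchmark
    # then substitutes this path for args.model_path.
    local_path = snapshot_download("Qwen/Qwen2-7B-Instruct")
    print(local_path)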
sglang/compile_deep_gemm.py CHANGED
@@ -17,6 +17,7 @@ import time
 
 import requests
 
+from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
 from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.managers.io_struct import GenerateReqInput
 from sglang.srt.managers.tokenizer_manager import TokenizerManager
@@ -52,7 +53,9 @@ class CompileArgs:
 
 
 @warmup("compile-deep-gemm")
-async def warm_up_compile(tokenizer_manager: TokenizerManager):
+async def warm_up_compile(
+    disaggregation_mode: str, tokenizer_manager: TokenizerManager
+):
     print("\nGenerate warm up request for compiling DeepGEMM...\n")
     generate_req_input = GenerateReqInput(
         input_ids=[0, 1, 2, 3],
@@ -62,6 +65,10 @@ async def warm_up_compile(tokenizer_manager: TokenizerManager):
             "ignore_eos": True,
         },
    )
+    if disaggregation_mode != "null":
+        generate_req_input.bootstrap_room = 0
+        generate_req_input.bootstrap_host = FAKE_BOOTSTRAP_HOST
+
     await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
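Note: with prefill/decode disaggregation active, the DeepGEMM warm-up request cannot go through a real bootstrap handshake, so it is pinned to a fake bootstrap host. A sketch of the pattern using the names from the hunks above (assumes an initialized tokenizer_manager; the sampling dict mirrors the fragment shown, with elided fields omitted):

    from sglang.srt.disaggregation.utils import FAKE_BOOTSTRAP_HOST
    from sglang.srt.managers.io_struct import GenerateReqInput

    req = GenerateReqInput(
        input_ids=[0, 1, 2, 3],
        sampling_params={"ignore_eos": True},
    )
    # Only needed when disaggregation_mode != "null":
    req.bootstrap_room = 0
    req.bootstrap_host = FAKE_BOOTSTRAP_HOST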
sglang/global_config.py CHANGED
@@ -30,7 +30,11 @@ class GlobalConfig:
         self.default_new_token_ratio_decay_steps = float(
             os.environ.get("SGLANG_NEW_TOKEN_RATIO_DECAY_STEPS", 600)
         )
-
+        self.torch_empty_cache_interval = float(
+            os.environ.get(
+                "SGLANG_EMPTY_CACHE_INTERVAL", -1
+            )  # in seconds. Set if you observe high memory accumulation over a long serving period.
+        )
         # Runtime constants: others
         self.retract_decode_steps = 20
         self.flashinfer_workspace_size = os.environ.get(
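Note: torch_empty_cache_interval defaults to -1, i.e. disabled; its consumer lives elsewhere in this release (the scheduler changes are not shown in this excerpt). A hedged sketch of how such an interval is typically applied:

    import time

    import torch

    EMPTY_CACHE_INTERVAL = 300.0  # seconds, i.e. SGLANG_EMPTY_CACHE_INTERVAL=300
    _last_empty_cache = time.monotonic()

    def maybe_empty_cache():
        # Periodically return cached allocator blocks to the driver to bound
        # long-horizon memory accumulation; a no-op when the interval is <= 0.
        global _last_empty_cache
        now = time.monotonic()
        if EMPTY_CACHE_INTERVAL > 0 and now - _last_empty_cache >= EMPTY_CACHE_INTERVAL:
            torch.cuda.empty_cache()
            _last_empty_cache = now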
sglang/srt/configs/model_config.py CHANGED
@@ -112,6 +112,7 @@ class ModelConfig:
         mm_disabled_models = [
             "Gemma3ForConditionalGeneration",
             "Llama4ForConditionalGeneration",
+            "Step3VLForConditionalGeneration",
         ]
         if self.hf_config.architectures[0] in mm_disabled_models:
             enable_multimodal = False
sglang/srt/conversation.py CHANGED
@@ -954,20 +954,6 @@ register_conv_template(
     )
 )
 
-register_conv_template(
-    Conversation(
-        name="mimo-vl",
-        system_message="You are MiMo, an AI assistant developed by Xiaomi.",
-        system_template="<|im_start|>system\n{system_message}",
-        roles=("<|im_start|>user", "<|im_start|>assistant"),
-        sep="<|im_end|>\n",
-        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
-        stop_str=["<|im_end|>"],
-        image_token="<|vision_start|><|image_pad|><|vision_end|>",
-    )
-)
-
-
 register_conv_template(
     Conversation(
         name="qwen2-audio",
@@ -981,51 +967,11 @@ register_conv_template(
     )
 )
 
-register_conv_template(
-    Conversation(
-        name="llama_4_vision",
-        system_message="You are a helpful language and vision assistant. You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.",
-        system_template="<|header_start|>system<|header_end|>\n\n{system_message}<|eot|>",
-        roles=("user", "assistant"),
-        sep_style=SeparatorStyle.LLAMA4,
-        sep="",
-        stop_str="<|eot|>",
-        image_token="<|image|>",
-    )
-)
-
-register_conv_template(
-    Conversation(
-        name="step3-vl",
-        system_message="<|begin▁of▁sentence|>You are a helpful assistant",
-        system_template="{system_message}\n",
-        roles=(
-            "<|BOT|>user\n",
-            "<|BOT|>assistant\n<think>\n",
-        ),
-        sep="<|EOT|>",
-        sep_style=SeparatorStyle.NO_COLON_SINGLE,
-        stop_str="<|EOT|>",
-        image_token="<im_patch>",
-        # add_bos=True,
-    )
-)
-
 
 @register_conv_template_matching_function
 def match_internvl(model_path: str):
     if re.search(r"internvl", model_path, re.IGNORECASE):
         return "internvl-2-5"
-    if re.search(r"intern.*s1", model_path, re.IGNORECASE):
-        return "interns1"
-
-
-@register_conv_template_matching_function
-def match_llama_vision(model_path: str):
-    if re.search(r"llama.*3\.2.*vision", model_path, re.IGNORECASE):
-        return "llama_3_vision"
-    if re.search(r"llama.*4.*", model_path, re.IGNORECASE):
-        return "llama_4_vision"
 
 
 @register_conv_template_matching_function
@@ -1040,22 +986,6 @@ def match_vicuna(model_path: str):
         return "vicuna_v1.1"
 
 
-@register_conv_template_matching_function
-def match_llama2_chat(model_path: str):
-    if re.search(
-        r"llama-2.*chat|codellama.*instruct",
-        model_path,
-        re.IGNORECASE,
-    ):
-        return "llama-2"
-
-
-@register_conv_template_matching_function
-def match_mistral(model_path: str):
-    if re.search(r"pixtral|(mistral|mixtral).*instruct", model_path, re.IGNORECASE):
-        return "mistral"
-
-
 @register_conv_template_matching_function
 def match_deepseek_vl(model_path: str):
     if re.search(r"deepseek.*vl2", model_path, re.IGNORECASE):
@@ -1064,12 +994,6 @@ def match_deepseek_vl(model_path: str):
 
 @register_conv_template_matching_function
 def match_qwen_chat_ml(model_path: str):
-    if re.search(r"gme.*qwen.*vl", model_path, re.IGNORECASE):
-        return "gme-qwen2-vl"
-    if re.search(r"qwen.*vl", model_path, re.IGNORECASE):
-        return "qwen2-vl"
-    if re.search(r"qwen.*audio", model_path, re.IGNORECASE):
-        return "qwen2-audio"
     if re.search(
         r"llava-v1\.6-34b|llava-v1\.6-yi-34b|llava-next-video-34b|llava-onevision-qwen2",
         model_path,
@@ -1078,12 +1002,6 @@ def match_qwen_chat_ml(model_path: str):
         return "chatml-llava"
 
 
-@register_conv_template_matching_function
-def match_gemma3_instruct(model_path: str):
-    if re.search(r"gemma-3.*it", model_path, re.IGNORECASE):
-        return "gemma-it"
-
-
 @register_conv_template_matching_function
 def match_openbmb_minicpm(model_path: str):
     if re.search(r"minicpm-v", model_path, re.IGNORECASE):
@@ -1092,37 +1010,7 @@ def match_openbmb_minicpm(model_path: str):
         return "minicpmo"
 
 
-@register_conv_template_matching_function
-def match_moonshot_kimivl(model_path: str):
-    if re.search(r"kimi.*vl", model_path, re.IGNORECASE):
-        return "kimi-vl"
-
-
-@register_conv_template_matching_function
-def match_devstral(model_path: str):
-    if re.search(r"devstral", model_path, re.IGNORECASE):
-        return "devstral"
-
-
 @register_conv_template_matching_function
 def match_phi_4_mm(model_path: str):
     if "phi-4-multimodal" in model_path.lower():
         return "phi-4-mm"
-
-
-@register_conv_template_matching_function
-def match_vila(model_path: str):
-    if re.search(r"vila", model_path, re.IGNORECASE):
-        return "chatml"
-
-
-@register_conv_template_matching_function
-def match_mimo_vl(model_path: str):
-    if re.search(r"mimo.*vl", model_path, re.IGNORECASE):
-        return "mimo-vl"
-
-
-# @register_conv_template_matching_function
-# def match_step3(model_path: str):
-#     if re.search(r"step3", model_path, re.IGNORECASE):
-#         return "step3-vl"
sglang/srt/disaggregation/decode_schedule_batch_mixin.py CHANGED
@@ -88,6 +88,7 @@ class ScheduleBatchDisaggregationDecodeMixin:
         self.extend_lens = [r.extend_input_len for r in reqs]
         self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
         self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
+        self.multimodal_inputs = [r.multimodal_inputs for r in reqs]
 
         # Build sampling info
         self.sampling_info = SamplingBatchInfo.from_schedule_batch(
sglang/srt/disaggregation/launch_lb.py CHANGED
@@ -1,6 +1,8 @@
 import argparse
 import dataclasses
 
+from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
+
 
 @dataclasses.dataclass
 class LBArgs:
@@ -18,7 +20,7 @@ class LBArgs:
         parser.add_argument(
             "--rust-lb",
             action="store_true",
-            help="Use Rust load balancer",
+            help="Deprecated, please use SGLang Router instead, this argument will have no effect.",
         )
         parser.add_argument(
             "--host",
@@ -115,25 +117,8 @@ def main():
     args = parser.parse_args()
     lb_args = LBArgs.from_cli_args(args)
 
-    if lb_args.rust_lb:
-        from sgl_pdlb._rust import LoadBalancer as RustLB
-
-        RustLB(
-            host=lb_args.host,
-            port=lb_args.port,
-            policy=lb_args.policy,
-            prefill_infos=lb_args.prefill_infos,
-            decode_infos=lb_args.decode_infos,
-            log_interval=lb_args.log_interval,
-            timeout=lb_args.timeout,
-        ).start()
-    else:
-        from sglang.srt.disaggregation.mini_lb import PrefillConfig, run
-
-        prefill_configs = [
-            PrefillConfig(url, port) for url, port in lb_args.prefill_infos
-        ]
-        run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
+    prefill_configs = [PrefillConfig(url, port) for url, port in lb_args.prefill_infos]
+    run(prefill_configs, lb_args.decode_infos, lb_args.host, lb_args.port)
 
 
 if __name__ == "__main__":
sglang/srt/disaggregation/mooncake/conn.py CHANGED
@@ -37,6 +37,7 @@ from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
     format_tcp_address,
+    get_bool_env_var,
     get_free_port,
     get_int_env_var,
     get_ip,
@@ -198,6 +199,10 @@ class MooncakeKVManager(BaseKVManager):
             self.bootstrap_timeout = get_int_env_var(
                 "SGLANG_DISAGGREGATION_BOOTSTRAP_TIMEOUT", 300
             )
+
+            self.enable_custom_mem_pool = get_bool_env_var(
+                "SGLANG_MOONCAKE_CUSTOM_MEM_POOL", "false"
+            )
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
             self.heartbeat_failures = {}
             self.session_pool = defaultdict(requests.Session)
@@ -258,6 +263,26 @@ class MooncakeKVManager(BaseKVManager):
         socket.connect(endpoint)
         return socket
 
+    def _transfer_data(self, mooncake_session_id, transfer_blocks):
+        if not transfer_blocks:
+            return 0
+
+        # TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
+        if self.enable_custom_mem_pool:
+            # batch_transfer_sync has a higher chance to trigger an accuracy drop for MNNVL, fallback to transfer_sync temporarily
+            for src_addr, dst_addr, length in transfer_blocks:
+                status = self.engine.transfer_sync(
+                    mooncake_session_id, src_addr, dst_addr, length
+                )
+                if status != 0:
+                    return status
+            return 0
+        else:
+            src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
+            return self.engine.batch_transfer_sync(
+                mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
+            )
+
     def send_kvcache(
         self,
         mooncake_session_id: str,
@@ -283,17 +308,14 @@ class MooncakeKVManager(BaseKVManager):
 
         # Worker function for processing a single layer
         def process_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
+            transfer_blocks = []
             for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                 src_addr = src_ptr + int(prefill_index[0]) * item_len
                 dst_addr = dst_ptr + int(decode_index[0]) * item_len
                 length = item_len * len(prefill_index)
+                transfer_blocks.append((src_addr, dst_addr, length))
 
-                status = self.engine.transfer_sync(
-                    mooncake_session_id, src_addr, dst_addr, length
-                )
-                if status != 0:
-                    return status
-            return 0
+            return self._transfer_data(mooncake_session_id, transfer_blocks)
 
         futures = [
             executor.submit(
@@ -465,21 +487,17 @@ class MooncakeKVManager(BaseKVManager):
         dst_aux_ptrs: list[int],
         dst_aux_index: int,
     ):
-        src_addr_list = []
-        dst_addr_list = []
-        length_list = []
+        transfer_blocks = []
         prefill_aux_ptrs = self.kv_args.aux_data_ptrs
         prefill_aux_item_lens = self.kv_args.aux_item_lens
+
         for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
             length = prefill_aux_item_lens[i]
             src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
             dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
-            src_addr_list.append(src_addr)
-            dst_addr_list.append(dst_addr)
-            length_list.append(length)
-        return self.engine.batch_transfer_sync(
-            mooncake_session_id, src_addr_list, dst_addr_list, length_list
-        )
+            transfer_blocks.append((src_addr, dst_addr, length))
+
+        return self._transfer_data(mooncake_session_id, transfer_blocks)
 
     def sync_status_to_decode_endpoint(
         self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
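Note: _transfer_data is now the single choke point for both KV-cache and aux transfers. It coalesces (src, dst, length) tuples into one batch_transfer_sync call, but falls back to per-block transfer_sync when the custom memory pool is enabled (per the MNNVL accuracy TODO above). The dispatch in isolation, as a sketch where `engine` stands in for the Mooncake transfer engine and is not part of this diff:

    def transfer_data(engine, session_id, blocks, use_per_block_fallback):
        if not blocks:
            return 0
        if use_per_block_fallback:
            # Per-block path: surface the first non-zero status immediately.
            for src, dst, length in blocks:
                status = engine.transfer_sync(session_id, src, dst, length)
                if status != 0:
                    return status
            return 0
        # Batched path: one call over parallel address/length lists.
        srcs, dsts, lens = map(list, zip(*blocks))
        return engine.batch_transfer_sync(session_id, srcs, dsts, lens)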
sglang/srt/disaggregation/prefill.py CHANGED
@@ -460,6 +460,7 @@ class SchedulerDisaggregationPrefillMixin:
 
         # We need to remove the sync in the following function for overlap schedule.
         self.set_next_batch_sampling_info_done(batch)
+        self.maybe_send_health_check_signal()
 
     def process_disagg_prefill_inflight_queue(
         self: Scheduler, rids_to_check: Optional[List[str]] = None
sglang/srt/distributed/device_communicators/pynccl.py CHANGED
@@ -75,6 +75,7 @@ class PyNcclCommunicator:
         self.available = True
         self.disabled = False
 
+        self.nccl_version = self.nccl.ncclGetRawVersion()
         if self.rank == 0:
             logger.info("sglang is using nccl==%s", self.nccl.ncclGetVersion())
 
@@ -259,6 +260,12 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )
 
+    def register_comm_window_raw(self, ptr: int, size: int):
+        return self.nccl.ncclCommWindowRegister(self.comm, buffer_type(ptr), size, 1)
+
+    def deregister_comm_window(self, window):
+        return self.nccl.ncclCommWindowDeregister(self.comm, window)
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
sglang/srt/distributed/device_communicators/pynccl_allocator.py ADDED
@@ -0,0 +1,133 @@
+import tempfile
+
+import torch
+from packaging import version
+from torch.cuda.memory import CUDAPluggableAllocator
+
+from sglang.srt.distributed.parallel_state import GroupCoordinator
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+
+nccl_allocator_source = """
+#include <nccl.h>
+extern "C" {
+
+void* nccl_alloc_plug(size_t size, int device, void* stream) {
+  void* ptr;
+  ncclResult_t err = ncclMemAlloc(&ptr, size);
+  return ptr;
+
+}
+
+void nccl_free_plug(void* ptr, size_t size, int device, void* stream) {
+  ncclResult_t err = ncclMemFree(ptr);
+}
+
+}
+"""
+
+_allocator = None
+_mem_pool = None
+_registered_base_addrs = set()
+_graph_pool_id = None
+
+
+def is_symmetric_memory_enabled():
+    return global_server_args_dict["enable_symm_mem"]
+
+
+def set_graph_pool_id(graph_pool_id):
+    global _graph_pool_id
+    _graph_pool_id = graph_pool_id
+
+
+def get_nccl_mem_pool():
+    global _allocator, _mem_pool
+    if _mem_pool is None:
+        out_dir = tempfile.gettempdir()
+        nccl_allocator_libname = "nccl_allocator"
+        torch.utils.cpp_extension.load_inline(
+            name=nccl_allocator_libname,
+            cpp_sources=nccl_allocator_source,
+            with_cuda=True,
+            extra_ldflags=["-lnccl"],
+            verbose=True,
+            is_python_module=False,
+            build_directory=out_dir,
+        )
+        _allocator = CUDAPluggableAllocator(
+            f"{out_dir}/{nccl_allocator_libname}.so",
+            "nccl_alloc_plug",
+            "nccl_free_plug",
+        ).allocator()
+        _mem_pool = torch.cuda.MemPool(_allocator)
+    return _mem_pool
+
+
+class use_symmetric_memory:
+    def __init__(self, group_coordinator: GroupCoordinator):
+        if not is_symmetric_memory_enabled():
+            self.group_coordinator = None
+            self._mem_pool_ctx = None
+            self.is_graph_capture = None
+            self.device = None
+            self.pre_2_8_0 = None
+        else:
+            self.group_coordinator = group_coordinator
+            self._mem_pool_ctx = torch.cuda.use_mem_pool(get_nccl_mem_pool())
+            self.is_graph_capture = torch.cuda.is_current_stream_capturing()
+            self.device = torch.cuda.current_device()
+            self.pre_2_8_0 = version.parse(torch.__version__) < version.parse("2.8.0")
+
+    def __enter__(self):
+        if not is_symmetric_memory_enabled():
+            return self
+        assert (
+            self.group_coordinator.pynccl_comm is not None
+        ), f"Symmetric memory requires pynccl to be enabled in group '{self.group_coordinator.group_name}'"
+        assert (
+            self.group_coordinator.pynccl_comm.nccl_version >= 22703
+        ), "NCCL version 2.27.3 or higher is required for NCCL symmetric memory"
+        if self.is_graph_capture:
+            assert (
+                _graph_pool_id is not None
+            ), "graph_pool_id is not set under graph capture"
+            # Pause graph memory pool to use symmetric memory with cuda graph
+            if self.pre_2_8_0:
+                torch._C._cuda_endAllocateCurrentStreamToPool(
+                    self.device, _graph_pool_id
+                )
+            else:
+                torch._C._cuda_endAllocateToPool(self.device, _graph_pool_id)
+        self._mem_pool_ctx.__enter__()
+        return self
+
+    def tag(self, tensor: torch.Tensor):
+        if not is_symmetric_memory_enabled():
+            return
+        tensor.symmetric_memory = True
+
+    def __exit__(self, exc_type, exc_val, exc_tb):
+        if not is_symmetric_memory_enabled():
+            return
+        global _registered_base_addrs
+        self._mem_pool_ctx.__exit__(exc_type, exc_val, exc_tb)
+        for segment in get_nccl_mem_pool().snapshot():
+            if segment["address"] not in _registered_base_addrs:
+                if segment["stream"] == 0 and self.pre_2_8_0:
+                    # PyTorch version < 2.8.0 has a multi-thread MemPool bug
+                    # See https://github.com/pytorch/pytorch/issues/152861
+                    # Fixed at https://github.com/pytorch/pytorch/commit/f01e628e3b31852983ab30b25bf251f557ba9c0b
+                    # WAR is to skip allocations on the default stream since the forward_pass thread always runs on a custom stream
+                    continue
+                self.group_coordinator.pynccl_comm.register_comm_window_raw(
+                    segment["address"], segment["total_size"]
+                )
+                _registered_base_addrs.add(segment["address"])
+
+        if self.is_graph_capture:
+            if self.pre_2_8_0:
+                torch._C._cuda_beginAllocateToPool(self.device, _graph_pool_id)
+            else:
+                torch._C._cuda_beginAllocateCurrentThreadToPool(
+                    self.device, _graph_pool_id
+                )
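Note: use_symmetric_memory is a context manager: allocations made inside it come from the NCCL MemPool (backed by ncclMemAlloc), new pool segments are registered as NCCL windows on exit, and tag() marks tensors so that GroupCoordinator.all_reduce (see the parallel_state.py hunk below) takes the symmetric-memory path. A hypothetical call site, assuming tp_group is a GroupCoordinator with pynccl enabled and NCCL >= 2.27.3:

    import torch

    from sglang.srt.distributed.device_communicators.pynccl_allocator import (
        use_symmetric_memory,
    )

    with use_symmetric_memory(tp_group) as sm:
        # Allocated through the NCCL pluggable allocator, not the default caching allocator.
        out = torch.empty(8192, dtype=torch.bfloat16, device="cuda")
        sm.tag(out)  # sets out.symmetric_memory = True

    out = tp_group.all_reduce(out)  # dispatches to pynccl over the registered window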
sglang/srt/distributed/device_communicators/pynccl_wrapper.py CHANGED
@@ -67,6 +67,7 @@ def find_nccl_library() -> str:
 
 ncclResult_t = ctypes.c_int
 ncclComm_t = ctypes.c_void_p
+ncclWindow_t = ctypes.c_void_p
 
 
 class ncclUniqueId(ctypes.Structure):
@@ -279,6 +280,23 @@ class NCCLLibrary:
         Function("ncclCommDestroy", ncclResult_t, [ncclComm_t]),
     ]
 
+    exported_functions_symm_mem = [
+        # ncclResult_t ncclCommWindowRegister(ncclComm_t comm, void* buff, size_t size, ncclWindow_t* win, int winFlags);
+        Function(
+            "ncclCommWindowRegister",
+            ncclResult_t,
+            [
+                ncclComm_t,
+                buffer_type,
+                ctypes.c_size_t,
+                ctypes.POINTER(ncclWindow_t),
+                ctypes.c_int,
+            ],
+        ),
+        # ncclResult_t ncclCommWindowDeregister(ncclComm_t comm, ncclWindow_t win);
+        Function("ncclCommWindowDeregister", ncclResult_t, [ncclComm_t, ncclWindow_t]),
+    ]
+
     # class attribute to store the mapping from the path to the library
     # to avoid loading the same library multiple times
     path_to_library_cache: Dict[str, Any] = {}
@@ -312,7 +330,10 @@ class NCCLLibrary:
 
         if so_file not in NCCLLibrary.path_to_dict_mapping:
             _funcs: Dict[str, Any] = {}
-            for func in NCCLLibrary.exported_functions:
+            exported_functions = NCCLLibrary.exported_functions
+            if hasattr(self.lib, "ncclCommWindowRegister"):
+                exported_functions.extend(NCCLLibrary.exported_functions_symm_mem)
+            for func in exported_functions:
                 f = getattr(self.lib, func.name)
                 f.restype = func.restype
                 f.argtypes = func.argtypes
@@ -328,10 +349,14 @@ class NCCLLibrary:
             error_str = self.ncclGetErrorString(result)
             raise RuntimeError(f"NCCL error: {error_str}")
 
-    def ncclGetVersion(self) -> str:
+    def ncclGetRawVersion(self) -> int:
         version = ctypes.c_int()
         self.NCCL_CHECK(self._funcs["ncclGetVersion"](ctypes.byref(version)))
-        version_str = str(version.value)
+        # something like 21903
+        return version.value
+
+    def ncclGetVersion(self) -> str:
+        version_str = str(self.ncclGetRawVersion())
         # something like 21903 --> "2.19.3"
         major = version_str[0].lstrip("0")
         minor = version_str[1:3].lstrip("0")
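Note: ncclGetRawVersion() exposes the undecoded integer (e.g. 21903) so callers such as the symmetric-memory check above can compare against 22703 (NCCL 2.27.3) without string parsing, while ncclGetVersion() keeps formatting it for logs. A worked example mirroring the slicing shown:

    def decode_nccl_version(raw: int) -> str:
        # Mirrors ncclGetVersion above: digit 0 is the major version,
        # digits 1-2 the minor, and the remainder the patch level.
        s = str(raw)
        major = s[0].lstrip("0")
        minor = s[1:3].lstrip("0")
        patch = s[3:].lstrip("0")
        return f"{major}.{minor}.{patch}"

    assert decode_nccl_version(21903) == "2.19.3"
    assert decode_nccl_version(22703) == "2.27.3"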
@@ -460,6 +485,20 @@ class NCCLLibrary:
     def ncclCommDestroy(self, comm: ncclComm_t) -> None:
         self.NCCL_CHECK(self._funcs["ncclCommDestroy"](comm))
 
+    def ncclCommWindowRegister(
+        self, comm: ncclComm_t, buff: buffer_type, size: int, win_flags: int
+    ) -> ncclWindow_t:
+        window = ncclWindow_t()
+        self.NCCL_CHECK(
+            self._funcs["ncclCommWindowRegister"](
+                comm, buff, size, ctypes.byref(window), win_flags
+            )
+        )
+        return window
+
+    def ncclCommWindowDeregister(self, comm: ncclComm_t, window: ncclWindow_t) -> None:
+        self.NCCL_CHECK(self._funcs["ncclCommWindowDeregister"](comm, window))
+
 
 __all__ = [
     "NCCLLibrary",
sglang/srt/distributed/parallel_state.py CHANGED
@@ -497,6 +497,17 @@ class GroupCoordinator:
         if self.npu_communicator is not None and not self.npu_communicator.disabled:
             return self.npu_communicator.all_reduce(input_)
 
+        if (
+            self.pynccl_comm is not None
+            and hasattr(input_, "symmetric_memory")
+            and input_.symmetric_memory
+        ):
+            with self.pynccl_comm.change_state(
+                enable=True, stream=torch.cuda.current_stream()
+            ):
+                self.pynccl_comm.all_reduce(input_)
+            return input_
+
         outplace_all_reduce_method = None
         if (
             self.qr_comm is not None
sglang/srt/entrypoints/engine.py CHANGED
@@ -623,8 +623,9 @@ class Engine(EngineBase):
 def _set_envs_and_config(server_args: ServerArgs):
     # Set global environments
     os.environ["TF_CPP_MIN_LOG_LEVEL"] = "3"
-    os.environ["NCCL_CUMEM_ENABLE"] = "0"
-    os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
+    os.environ["NCCL_CUMEM_ENABLE"] = str(int(server_args.enable_symm_mem))
+    if not server_args.enable_symm_mem:
+        os.environ["NCCL_NVLS_ENABLE"] = str(int(server_args.enable_nccl_nvls))
     os.environ["TORCH_NCCL_AVOID_RECORD_STREAMS"] = "1"
     os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
     os.environ["CUDA_MODULE_LOADING"] = "AUTO"
@@ -731,6 +732,7 @@ def _launch_subprocesses(
             pp_rank,
             None,
             writer,
+            None,
         ),
     )