sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200) hide show
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
@@ -35,7 +35,15 @@ from sglang.srt.disaggregation.common.utils import (
35
35
  from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
36
  from sglang.srt.disaggregation.utils import DisaggregationMode
37
37
  from sglang.srt.server_args import ServerArgs
38
- from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
38
+ from sglang.srt.utils import (
39
+ format_tcp_address,
40
+ get_free_port,
41
+ get_int_env_var,
42
+ get_ip,
43
+ get_local_ip_auto,
44
+ is_valid_ipv6_address,
45
+ maybe_wrap_ipv6_address,
46
+ )
39
47
 
40
48
  logger = logging.getLogger(__name__)
41
49
 
@@ -148,6 +156,9 @@ class MooncakeKVManager(BaseKVManager):
148
156
  self.request_status: Dict[int, KVPoll] = {}
149
157
  self.rank_port = None
150
158
  self.server_socket = zmq.Context().socket(zmq.PULL)
159
+ if is_valid_ipv6_address(self.local_ip):
160
+ self.server_socket.setsockopt(zmq.IPV6, 1)
161
+
151
162
  self.register_buffer_to_engine()
152
163
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
153
164
  self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
@@ -240,8 +251,10 @@ class MooncakeKVManager(BaseKVManager):
240
251
  self.engine.register(aux_data_ptr, aux_data_len)
241
252
 
242
253
  @cache
243
- def _connect(self, endpoint: str):
254
+ def _connect(self, endpoint: str, is_ipv6: bool = False):
244
255
  socket = zmq.Context().socket(zmq.PUSH)
256
+ if is_ipv6:
257
+ socket.setsockopt(zmq.IPV6, 1)
245
258
  socket.connect(endpoint)
246
259
  return socket
247
260
 
@@ -321,67 +334,60 @@ class MooncakeKVManager(BaseKVManager):
321
334
  This may introduce performance overhead (increased TTFT) for long sequences.
322
335
  """
323
336
  # Extract configuration
324
- local_tp_rank = self.kv_args.engine_rank
325
337
  local_tp_size = self.tp_size // self.dp_size
338
+ local_tp_rank_in_group = self.kv_args.engine_rank % local_tp_size
339
+ src_kv_item_len = self.kv_args.kv_item_lens[0]
340
+ dst_tp_rank_in_group = dst_tp_rank % dst_tp_size
326
341
  num_kv_heads = self.kv_args.kv_head_num
327
342
  num_layers = len(self.kv_args.kv_data_ptrs)
328
343
  page_size = self.kv_args.page_size
329
344
 
330
345
  # Calculate head distribution
331
- heads_per_decode_rank = num_kv_heads * local_tp_size // dst_tp_size
332
- heads_per_prefill_rank = num_kv_heads
333
- decode_global_head_start = dst_tp_rank * heads_per_decode_rank
334
- prefill_global_head_start = local_tp_rank * heads_per_prefill_rank
335
- bytes_per_head = dst_kv_item_len // heads_per_decode_rank // page_size
336
-
337
- decode_rank_item_lens = [dst_kv_item_len for _ in range(num_layers)]
346
+ src_heads_per_rank = num_kv_heads
347
+ dst_heads_per_rank = num_kv_heads * local_tp_size // dst_tp_size
348
+ bytes_per_head_slice_to_send = (
349
+ dst_kv_item_len // page_size // dst_heads_per_rank
350
+ )
338
351
 
339
352
  # Determine slicing parameters based on TP configuration
340
353
  if local_tp_size > dst_tp_size:
341
- src_head_offset = 0
342
- num_heads_to_send = heads_per_prefill_rank
343
- dst_head_offset = prefill_global_head_start - decode_global_head_start
354
+ # Send KVCache from multiple prefill instances to 1 decode instance
355
+ src_head_start_offset = 0
356
+ num_heads_to_send = src_heads_per_rank
357
+ dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
344
358
  else:
345
- src_head_offset = decode_global_head_start - prefill_global_head_start
346
- num_heads_to_send = heads_per_decode_rank
347
- dst_head_offset = 0
359
+ # Send KVCache from 1 prefill instance to multiple decode instances
360
+ src_head_start_offset = dst_tp_rank_in_group * dst_heads_per_rank
361
+ num_heads_to_send = dst_heads_per_rank
362
+ dst_head_start_offset = 0
348
363
 
349
- layer_transfer_params = []
364
+ layers_params = []
350
365
  for layer_id in range(num_layers):
351
- item_len_of_prefill_rank_page = self.kv_args.kv_item_lens[layer_id]
352
-
353
- # Page stride on the target dst decode rank for its slice pages
354
- item_len_of_decode_rank_page = decode_rank_item_lens[layer_id]
355
-
356
- if item_len_of_prefill_rank_page == 0 or num_kv_heads == 0:
357
- logger.error(
358
- f"Invalid item_len_of_prefill_rank_page or num_kv_heads for layer {layer_id}"
359
- )
360
- return -1
361
-
362
- # Calculate precise byte offset and length for the sub-slice within the prefill page data
363
- src_slice_offset = src_head_offset * bytes_per_head
364
- dst_slice_offset = dst_head_offset * bytes_per_head
365
- slice_lens_per_page = num_heads_to_send * bytes_per_head
366
+ # Calculate precise byte offset and length for the sub-slice within the token
367
+ src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
368
+ dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
369
+ heads_bytes_per_token_to_send = (
370
+ num_heads_to_send * bytes_per_head_slice_to_send
371
+ )
366
372
 
367
- # Sanity check: The data sub-slice to be sent should fit into the decode instance's page.
368
- # This means slice_lens_per_page <= item_len_of_decode_rank_page
369
- if slice_lens_per_page > item_len_of_decode_rank_page:
373
+ # Sanity check: The data sub-slice to be sent should fit into the dst buffer.
374
+ # This means heads_bytes_per_token_to_send <= (dst_kv_item_len // page_size)
375
+ if heads_bytes_per_token_to_send > (dst_kv_item_len // page_size):
370
376
  logger.error(
371
377
  f"[{mooncake_session_id}] Layer {layer_id}: "
372
- f"slice size ({slice_lens_per_page}) exceeds "
373
- f"target page size ({item_len_of_decode_rank_page})"
378
+ f"slice size ({heads_bytes_per_token_to_send}) exceeds "
379
+ f"target token slot size ({dst_kv_item_len // page_size})"
374
380
  )
375
381
  return -1
376
- layer_transfer_params.append(
382
+ layers_params.append(
377
383
  (
378
384
  self.kv_args.kv_data_ptrs[layer_id],
379
385
  dst_kv_ptrs[layer_id],
380
- item_len_of_prefill_rank_page,
381
- item_len_of_decode_rank_page,
382
- src_slice_offset,
383
- dst_slice_offset,
384
- slice_lens_per_page,
386
+ src_kv_item_len,
387
+ dst_kv_item_len,
388
+ src_head_slice_offset,
389
+ dst_head_slice_offset,
390
+ heads_bytes_per_token_to_send,
385
391
  )
386
392
  )
387
393
 
@@ -391,9 +397,9 @@ class MooncakeKVManager(BaseKVManager):
391
397
  dst_ptr,
392
398
  src_item_len,
393
399
  dst_item_len,
394
- src_offset,
395
- dst_offset,
396
- slice_lens_per_page,
400
+ src_head_slice_offset,
401
+ dst_head_slice_offset,
402
+ heads_bytes_per_token_to_send,
397
403
  ) = layer_params
398
404
  src_addr_list = []
399
405
  dst_addr_list = []
@@ -424,17 +430,12 @@ class MooncakeKVManager(BaseKVManager):
424
430
  )
425
431
 
426
432
  # Calculate final src and dst addresses by applying head-slice offsets
427
- src_slice_addr = src_token_slot_start_addr + src_offset
428
- dst_slice_addr = dst_token_slot_start_addr + dst_offset
433
+ src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
434
+ dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
429
435
 
430
436
  src_addr_list.append(src_slice_addr)
431
437
  dst_addr_list.append(dst_slice_addr)
432
- length_list.append(slice_lens_per_page)
433
-
434
- logger.debug(
435
- f"SYNC: sid={mooncake_session_id}, "
436
- f"src={src_slice_addr}, dst={dst_slice_addr}, len={slice_lens_per_page}"
437
- )
438
+ length_list.append(heads_bytes_per_token_to_send)
438
439
 
439
440
  return self.engine.batch_transfer_sync(
440
441
  mooncake_session_id, src_addr_list, dst_addr_list, length_list
@@ -445,7 +446,7 @@ class MooncakeKVManager(BaseKVManager):
445
446
  process_layer_tp_aware,
446
447
  layer_params,
447
448
  )
448
- for layer_params in layer_transfer_params
449
+ for layer_params in layers_params
449
450
  ]
450
451
 
451
452
  for future in concurrent.futures.as_completed(futures):
@@ -483,9 +484,9 @@ class MooncakeKVManager(BaseKVManager):
483
484
  def sync_status_to_decode_endpoint(
484
485
  self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
485
486
  ):
486
- if ":" in remote:
487
- remote = remote.split(":")[0]
488
- self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
487
+ self._connect(
488
+ format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
489
+ ).send_multipart(
489
490
  [
490
491
  str(room).encode("ascii"),
491
492
  str(status).encode("ascii"),
@@ -533,12 +534,12 @@ class MooncakeKVManager(BaseKVManager):
533
534
  if len(chunked_dst_kv_indice) < len(
534
535
  kv_chunk.prefill_kv_indices
535
536
  ):
536
- kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
537
- : len(chunked_dst_kv_indice)
538
- ]
539
537
  logger.warning(
540
538
  f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
541
539
  )
540
+ kv_chunk.prefill_kv_indices = kv_chunk.prefill_kv_indices[
541
+ : len(chunked_dst_kv_indice)
542
+ ]
542
543
 
543
544
  target_rank_registration_info: KVArgsRegisterInfo = (
544
545
  self.decode_kv_args_table[req.mooncake_session_id]
@@ -628,9 +629,12 @@ class MooncakeKVManager(BaseKVManager):
628
629
  f"Transfer thread failed because of {e}. Prefill instance with bootstrap_port={self.bootstrap_port} is dead."
629
630
  )
630
631
 
632
+ def _bind_server_socket(self):
633
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
634
+
631
635
  def start_prefill_thread(self):
632
636
  self.rank_port = get_free_port()
633
- self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
637
+ self._bind_server_socket()
634
638
 
635
639
  def bootstrap_thread():
636
640
  """This thread recvs pre-alloc notification from the decode engine"""
@@ -669,7 +673,7 @@ class MooncakeKVManager(BaseKVManager):
669
673
 
670
674
  def start_decode_thread(self):
671
675
  self.rank_port = get_free_port()
672
- self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
676
+ self._bind_server_socket()
673
677
 
674
678
  def decode_thread():
675
679
  while True:
@@ -788,7 +792,7 @@ class MooncakeKVManager(BaseKVManager):
788
792
  # requests with the same dst_sessions will be added into the same
789
793
  # queue, which enables early abort with failed sessions.
790
794
  dst_infos = self.transfer_infos[bootstrap_room].keys()
791
- session_port_sum = sum(int(session.split(":")[1]) for session in dst_infos)
795
+ session_port_sum = sum(int(session.rsplit(":", 1)[1]) for session in dst_infos)
792
796
  shard_idx = session_port_sum % len(self.transfer_queues)
793
797
 
794
798
  self.transfer_queues[shard_idx].put(
@@ -826,11 +830,18 @@ class MooncakeKVManager(BaseKVManager):
826
830
  def _register_to_bootstrap(self):
827
831
  """Register KVSender to bootstrap server via HTTP POST."""
828
832
  if self.dist_init_addr:
829
- ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
833
+ if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
834
+ if self.dist_init_addr.endswith("]"):
835
+ host = self.dist_init_addr
836
+ else:
837
+ host, _ = self.dist_init_addr.rsplit(":", 1)
838
+ else:
839
+ host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
830
840
  else:
831
- ip_address = get_ip()
841
+ host = get_ip()
842
+ host = maybe_wrap_ipv6_address(host)
832
843
 
833
- bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
844
+ bootstrap_server_url = f"{host}:{self.bootstrap_port}"
834
845
  url = f"http://{bootstrap_server_url}/route"
835
846
  payload = {
836
847
  "role": "Prefill",
@@ -1175,9 +1186,6 @@ class MooncakeKVReceiver(BaseKVReceiver):
1175
1186
 
1176
1187
  def _register_kv_args(self):
1177
1188
  for bootstrap_info in self.bootstrap_infos:
1178
- self.prefill_server_url = (
1179
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
1180
- )
1181
1189
  packed_kv_data_ptrs = b"".join(
1182
1190
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
1183
1191
  )
@@ -1191,7 +1199,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
1191
1199
  dst_tp_size = str(tp_size).encode("ascii")
1192
1200
  dst_kv_item_len = str(kv_item_len).encode("ascii")
1193
1201
 
1194
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
1202
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1195
1203
  with lock:
1196
1204
  sock.send_multipart(
1197
1205
  [
@@ -1208,23 +1216,32 @@ class MooncakeKVReceiver(BaseKVReceiver):
1208
1216
  )
1209
1217
 
1210
1218
  @classmethod
1211
- def _connect(cls, endpoint: str):
1219
+ def _connect(cls, endpoint: str, is_ipv6: bool = False):
1212
1220
  with cls._global_lock:
1213
1221
  if endpoint not in cls._socket_cache:
1214
1222
  sock = cls._ctx.socket(zmq.PUSH)
1223
+ if is_ipv6:
1224
+ sock.setsockopt(zmq.IPV6, 1)
1215
1225
  sock.connect(endpoint)
1216
1226
  cls._socket_cache[endpoint] = sock
1217
1227
  cls._socket_locks[endpoint] = threading.Lock()
1218
1228
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
1219
1229
 
1230
+ @classmethod
1231
+ def _connect_to_bootstrap_server(cls, bootstrap_info: dict):
1232
+ ip_address = bootstrap_info["rank_ip"]
1233
+ port = bootstrap_info["rank_port"]
1234
+ is_ipv6_address = is_valid_ipv6_address(ip_address)
1235
+ sock, lock = cls._connect(
1236
+ format_tcp_address(ip_address, port), is_ipv6=is_ipv6_address
1237
+ )
1238
+ return sock, lock
1239
+
1220
1240
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
1221
1241
  for bootstrap_info in self.bootstrap_infos:
1222
- self.prefill_server_url = (
1223
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
1224
- )
1242
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
1225
1243
  is_dummy = bootstrap_info["is_dummy"]
1226
1244
 
1227
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
1228
1245
  with lock:
1229
1246
  sock.send_multipart(
1230
1247
  [
@@ -1,7 +1,7 @@
1
1
  import logging
2
2
  from typing import List, Optional
3
3
 
4
- from sglang.srt.utils import get_bool_env_var, get_free_port
4
+ from sglang.srt.utils import get_bool_env_var, get_free_port, maybe_wrap_ipv6_address
5
5
 
6
6
  logger = logging.getLogger(__name__)
7
7
 
@@ -27,7 +27,9 @@ class MooncakeTransferEngine:
27
27
  hostname=self.hostname,
28
28
  device_name=self.ib_device,
29
29
  )
30
- self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
30
+ self.session_id = (
31
+ f"{maybe_wrap_ipv6_address(self.hostname)}:{self.engine.get_rpc_port()}"
32
+ )
31
33
 
32
34
  def register(self, ptr, length):
33
35
  try:
@@ -27,7 +27,11 @@ from sglang.srt.disaggregation.common.conn import (
27
27
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
28
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
29
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import get_local_ip_by_remote
30
+ from sglang.srt.utils import (
31
+ format_tcp_address,
32
+ get_local_ip_auto,
33
+ is_valid_ipv6_address,
34
+ )
31
35
 
32
36
  logger = logging.getLogger(__name__)
33
37
 
@@ -124,7 +128,10 @@ class NixlKVManager(CommonKVManager):
124
128
  "to run SGLang with NixlTransferEngine."
125
129
  ) from e
126
130
  self.agent = nixl_agent(str(uuid.uuid4()))
131
+ self.local_ip = get_local_ip_auto()
127
132
  self.server_socket = zmq.Context().socket(zmq.PULL)
133
+ if is_valid_ipv6_address(self.local_ip):
134
+ self.server_socket.setsockopt(zmq.IPV6, 1)
128
135
  self.register_buffer_to_engine()
129
136
 
130
137
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
@@ -337,8 +344,11 @@ class NixlKVManager(CommonKVManager):
337
344
  return False
338
345
  return self.transfer_statuses[room].is_done()
339
346
 
347
+ def _bind_server_socket(self):
348
+ self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
349
+
340
350
  def _start_bootstrap_thread(self):
341
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
351
+ self._bind_server_socket()
342
352
 
343
353
  def bootstrap_thread():
344
354
  """This thread recvs transfer info from the decode engine"""
@@ -452,23 +462,20 @@ class NixlKVReceiver(CommonKVReceiver):
452
462
 
453
463
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
454
464
  for bootstrap_info in self.bootstrap_infos:
455
- self.prefill_server_url = (
456
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
457
- )
458
465
  logger.debug(
459
466
  f"Fetched bootstrap info: {bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
460
467
  )
468
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
461
469
  is_dummy = bootstrap_info["is_dummy"]
462
470
  logger.debug(
463
- f"Sending to {self.prefill_server_url} with bootstrap room {self.bootstrap_room} {is_dummy=}"
471
+ f"Sending to prefill server with bootstrap room {self.bootstrap_room} {is_dummy=}"
464
472
  )
465
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
466
473
  with lock:
467
474
  sock.send_multipart(
468
475
  [
469
476
  GUARD,
470
477
  str(self.bootstrap_room).encode("ascii"),
471
- get_local_ip_by_remote().encode("ascii"),
478
+ self.kv_mgr.local_ip.encode("ascii"),
472
479
  str(self.kv_mgr.rank_port).encode("ascii"),
473
480
  self.kv_mgr.agent.name.encode("ascii"),
474
481
  kv_indices.tobytes() if not is_dummy else b"",
@@ -494,9 +501,7 @@ class NixlKVReceiver(CommonKVReceiver):
494
501
 
495
502
  def _register_kv_args(self):
496
503
  for bootstrap_info in self.bootstrap_infos:
497
- self.prefill_server_url = (
498
- f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
499
- )
504
+ sock, lock = self._connect_to_bootstrap_server(bootstrap_info)
500
505
  packed_kv_data_ptrs = b"".join(
501
506
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.kv_data_ptrs
502
507
  )
@@ -504,13 +509,12 @@ class NixlKVReceiver(CommonKVReceiver):
504
509
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
505
510
  )
506
511
 
507
- sock, lock = self._connect("tcp://" + self.prefill_server_url)
508
512
  with lock:
509
513
  sock.send_multipart(
510
514
  [
511
515
  GUARD,
512
516
  "None".encode("ascii"),
513
- get_local_ip_by_remote().encode("ascii"),
517
+ self.kv_mgr.local_ip.encode("ascii"),
514
518
  str(self.kv_mgr.rank_port).encode("ascii"),
515
519
  self.kv_mgr.agent.name.encode("ascii"),
516
520
  self.kv_mgr.agent.get_agent_metadata(),
@@ -4,18 +4,18 @@ import ctypes
4
4
  import logging
5
5
  import os
6
6
  from contextlib import contextmanager
7
- from functools import wraps
8
- from typing import Any, Callable, List, Optional, TypeVar, Union
7
+ from typing import Any, List, Optional, Union
9
8
 
10
9
  import torch
11
10
  import torch.distributed as dist
12
11
  from torch.distributed import ProcessGroup
13
- from typing_extensions import ParamSpec
14
12
 
15
13
  from sglang.srt import _custom_ops as ops
16
14
  from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
17
15
  from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import (
18
16
  gpu_p2p_access_check,
17
+ is_full_nvlink,
18
+ is_weak_contiguous,
19
19
  )
20
20
  from sglang.srt.distributed.parallel_state import in_the_same_node_as
21
21
  from sglang.srt.utils import is_cuda, is_hip
@@ -25,23 +25,6 @@ logger = logging.getLogger(__name__)
25
25
  _is_cuda = is_cuda()
26
26
  _is_hip = is_hip()
27
27
 
28
- if _is_cuda:
29
- try:
30
- import pynvml
31
- except ImportError as e:
32
- logger.warning("Failed to import pynvml with %r", e)
33
-
34
- if _is_hip:
35
- try:
36
- from amdsmi import (
37
- AmdSmiException,
38
- amdsmi_get_processor_handles,
39
- amdsmi_init,
40
- amdsmi_shut_down,
41
- amdsmi_topo_get_link_type,
42
- )
43
- except ImportError as e:
44
- logger.warning("Failed to import amdsmi with %r", e)
45
28
 
46
29
  try:
47
30
  if ops.use_vllm_custom_allreduce and not _is_hip:
@@ -57,70 +40,6 @@ except Exception:
57
40
 
58
41
  logger = logging.getLogger(__name__)
59
42
 
60
- _P = ParamSpec("_P")
61
- _R = TypeVar("_R")
62
-
63
-
64
- def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
65
- @wraps(fn)
66
- def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
67
- if _is_hip:
68
- try:
69
- amdsmi_init()
70
- return fn(*args, **kwargs)
71
- finally:
72
- amdsmi_shut_down()
73
- else:
74
- pynvml.nvmlInit()
75
- try:
76
- return fn(*args, **kwargs)
77
- finally:
78
- pynvml.nvmlShutdown()
79
-
80
- return wrapper
81
-
82
-
83
- @with_nvml_context
84
- def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
85
- if _is_hip:
86
- """
87
- query if the set of gpus are fully connected by xgmi (1 hop)
88
- """
89
- handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
90
- for i, handle in enumerate(handles):
91
- for j, peer_handle in enumerate(handles):
92
- if i < j:
93
- try:
94
- link_type = amdsmi_topo_get_link_type(handle, peer_handle)
95
- # type is 2 for XGMI
96
- if link_type["hops"] != 1 or link_type["type"] != 2:
97
- return False
98
- except AmdSmiException as error:
99
- logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
100
- return False
101
- return True
102
- else:
103
- """
104
- query if the set of gpus are fully connected by nvlink (1 hop)
105
- """
106
- handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
107
- for i, handle in enumerate(handles):
108
- for j, peer_handle in enumerate(handles):
109
- if i < j:
110
- try:
111
- p2p_status = pynvml.nvmlDeviceGetP2PStatus(
112
- handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
113
- )
114
- if p2p_status != pynvml.NVML_P2P_STATUS_OK:
115
- return False
116
- except pynvml.NVMLError:
117
- logger.exception(
118
- "NVLink detection failed. This is normal if your"
119
- " machine has no NVLink equipped."
120
- )
121
- return False
122
- return True
123
-
124
43
 
125
44
  def _can_p2p(rank: int, world_size: int) -> bool:
126
45
  # SGLANG_SKIP_P2P_CHECK can be set to False in sglang
@@ -136,13 +55,6 @@ def _can_p2p(rank: int, world_size: int) -> bool:
136
55
  return True
137
56
 
138
57
 
139
- def is_weak_contiguous(inp: torch.Tensor):
140
- return inp.is_contiguous() or (
141
- inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
142
- == inp.numel() * inp.element_size()
143
- )
144
-
145
-
146
58
  class CustomAllreduce:
147
59
  _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
148
60
  _MAX_CAR_SIZE = 8192 * 1024
@@ -8,17 +8,44 @@ import pickle
8
8
  import subprocess
9
9
  import sys
10
10
  import tempfile
11
+ from functools import wraps
11
12
  from itertools import product
12
- from typing import Dict, List, Optional, Sequence
13
+ from typing import Callable, Dict, List, Optional, Sequence, TypeVar
13
14
 
14
15
  import torch
15
16
  import torch.distributed as dist
16
17
  import torch.multiprocessing as mp
18
+ from typing_extensions import ParamSpec
17
19
 
18
20
  from sglang.srt.distributed.device_communicators.cuda_wrapper import CudaRTLibrary
21
+ from sglang.srt.utils import is_cuda, is_hip
19
22
 
20
23
  logger = logging.getLogger(__name__)
21
24
 
25
+ _is_cuda = is_cuda()
26
+ _is_hip = is_hip()
27
+
28
+ if _is_cuda:
29
+ try:
30
+ import pynvml
31
+ except ImportError as e:
32
+ logger.warning("Failed to import pynvml with %r", e)
33
+
34
+ if _is_hip:
35
+ try:
36
+ from amdsmi import (
37
+ AmdSmiException,
38
+ amdsmi_get_processor_handles,
39
+ amdsmi_init,
40
+ amdsmi_shut_down,
41
+ amdsmi_topo_get_link_type,
42
+ )
43
+ except ImportError as e:
44
+ logger.warning("Failed to import amdsmi with %r", e)
45
+
46
+ _P = ParamSpec("_P")
47
+ _R = TypeVar("_R")
48
+
22
49
 
23
50
  def update_environment_variables(envs: Dict[str, str]):
24
51
  for k, v in envs.items():
@@ -282,6 +309,74 @@ def gpu_p2p_access_check(src: int, tgt: int) -> bool:
282
309
  return _gpu_p2p_access_cache[f"{src}->{tgt}"]
283
310
 
284
311
 
312
+ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
313
+ @wraps(fn)
314
+ def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
315
+ if _is_hip:
316
+ try:
317
+ amdsmi_init()
318
+ return fn(*args, **kwargs)
319
+ finally:
320
+ amdsmi_shut_down()
321
+ else:
322
+ pynvml.nvmlInit()
323
+ try:
324
+ return fn(*args, **kwargs)
325
+ finally:
326
+ pynvml.nvmlShutdown()
327
+
328
+ return wrapper
329
+
330
+
331
+ @with_nvml_context
332
+ def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
333
+ if _is_hip:
334
+ """
335
+ query if the set of gpus are fully connected by xgmi (1 hop)
336
+ """
337
+ handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
338
+ for i, handle in enumerate(handles):
339
+ for j, peer_handle in enumerate(handles):
340
+ if i < j:
341
+ try:
342
+ link_type = amdsmi_topo_get_link_type(handle, peer_handle)
343
+ # type is 2 for XGMI
344
+ if link_type["hops"] != 1 or link_type["type"] != 2:
345
+ return False
346
+ except AmdSmiException as error:
347
+ logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
348
+ return False
349
+ return True
350
+ else:
351
+ """
352
+ query if the set of gpus are fully connected by nvlink (1 hop)
353
+ """
354
+ handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
355
+ for i, handle in enumerate(handles):
356
+ for j, peer_handle in enumerate(handles):
357
+ if i < j:
358
+ try:
359
+ p2p_status = pynvml.nvmlDeviceGetP2PStatus(
360
+ handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
361
+ )
362
+ if p2p_status != pynvml.NVML_P2P_STATUS_OK:
363
+ return False
364
+ except pynvml.NVMLError:
365
+ logger.exception(
366
+ "NVLink detection failed. This is normal if your"
367
+ " machine has no NVLink equipped."
368
+ )
369
+ return False
370
+ return True
371
+
372
+
373
+ def is_weak_contiguous(inp: torch.Tensor):
374
+ return inp.is_contiguous() or (
375
+ inp.storage().nbytes() - inp.storage_offset() * inp.element_size()
376
+ == inp.numel() * inp.element_size()
377
+ )
378
+
379
+
285
380
  __all__ = ["gpu_p2p_access_check"]
286
381
 
287
382
  if __name__ == "__main__":