sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl

This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that public registry.
Files changed (205)
  1. sglang/api.py +1 -1
  2. sglang/bench_offline_throughput.py +19 -0
  3. sglang/bench_one_batch.py +2 -2
  4. sglang/bench_serving.py +123 -79
  5. sglang/global_config.py +8 -3
  6. sglang/lang/backend/runtime_endpoint.py +1 -1
  7. sglang/lang/ir.py +1 -1
  8. sglang/srt/_custom_ops.py +83 -91
  9. sglang/srt/configs/load_config.py +4 -1
  10. sglang/srt/configs/model_config.py +48 -2
  11. sglang/srt/configs/qwen2_5_vl_config.py +5 -2
  12. sglang/srt/constrained/base_grammar_backend.py +117 -15
  13. sglang/srt/constrained/llguidance_backend.py +151 -0
  14. sglang/srt/constrained/outlines_backend.py +24 -33
  15. sglang/srt/constrained/xgrammar_backend.py +69 -38
  16. sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
  17. sglang/srt/distributed/parallel_state.py +48 -3
  18. sglang/srt/entrypoints/engine.py +67 -9
  19. sglang/srt/entrypoints/http_server.py +190 -41
  20. sglang/srt/entrypoints/verl_engine.py +147 -0
  21. sglang/srt/function_call_parser.py +0 -1
  22. sglang/srt/layers/activation.py +11 -0
  23. sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
  24. sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +220 -378
  26. sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
  27. sglang/srt/layers/attention/torch_native_backend.py +1 -1
  28. sglang/srt/layers/attention/triton_backend.py +9 -6
  29. sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
  30. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
  31. sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
  32. sglang/srt/layers/attention/utils.py +39 -0
  33. sglang/srt/layers/attention/vision.py +60 -63
  34. sglang/srt/layers/dp_attention.py +142 -1
  35. sglang/srt/layers/layernorm.py +1 -1
  36. sglang/srt/layers/linear.py +3 -1
  37. sglang/srt/layers/logits_processor.py +281 -45
  38. sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
  39. sglang/srt/layers/moe/ep_moe/layer.py +140 -28
  40. sglang/srt/layers/moe/fused_moe_native.py +2 -0
  41. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
  43. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
  44. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
  45. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
  46. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
  47. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
  48. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
  51. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
  55. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
  56. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
  57. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
  58. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
  59. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
  60. sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
  63. sglang/srt/layers/moe/topk.py +13 -4
  64. sglang/srt/layers/quantization/__init__.py +111 -7
  65. sglang/srt/layers/quantization/blockwise_int8.py +409 -0
  66. sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  67. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  68. sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  69. sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  70. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  71. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  72. sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  73. sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  74. sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  76. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  77. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  78. sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  79. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  80. sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  81. sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  82. sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  83. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  84. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  85. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  86. sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
  88. sglang/srt/layers/quantization/fp8.py +69 -28
  89. sglang/srt/layers/quantization/fp8_utils.py +17 -1
  90. sglang/srt/layers/quantization/gptq.py +416 -0
  91. sglang/srt/layers/quantization/int8_kernel.py +327 -0
  92. sglang/srt/layers/quantization/int8_utils.py +73 -0
  93. sglang/srt/layers/quantization/modelopt_quant.py +18 -1
  94. sglang/srt/layers/radix_attention.py +1 -0
  95. sglang/srt/layers/rotary_embedding.py +0 -1
  96. sglang/srt/layers/sampler.py +76 -31
  97. sglang/srt/layers/vocab_parallel_embedding.py +14 -13
  98. sglang/srt/lora/lora.py +17 -1
  99. sglang/srt/lora/lora_config.py +5 -0
  100. sglang/srt/lora/lora_manager.py +1 -3
  101. sglang/srt/managers/cache_controller.py +193 -62
  102. sglang/srt/managers/configure_logging.py +2 -1
  103. sglang/srt/managers/data_parallel_controller.py +6 -2
  104. sglang/srt/managers/detokenizer_manager.py +124 -102
  105. sglang/srt/managers/image_processor.py +2 -1
  106. sglang/srt/managers/io_struct.py +143 -6
  107. sglang/srt/managers/schedule_batch.py +237 -197
  108. sglang/srt/managers/schedule_policy.py +29 -29
  109. sglang/srt/managers/scheduler.py +681 -259
  110. sglang/srt/managers/session_controller.py +6 -2
  111. sglang/srt/managers/tokenizer_manager.py +224 -68
  112. sglang/srt/managers/tp_worker.py +15 -4
  113. sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
  114. sglang/srt/mem_cache/chunk_cache.py +18 -11
  115. sglang/srt/mem_cache/hiradix_cache.py +394 -0
  116. sglang/srt/mem_cache/memory_pool.py +44 -18
  117. sglang/srt/mem_cache/radix_cache.py +58 -47
  118. sglang/srt/metrics/collector.py +94 -36
  119. sglang/srt/model_executor/cuda_graph_runner.py +55 -24
  120. sglang/srt/model_executor/forward_batch_info.py +49 -16
  121. sglang/srt/model_executor/model_runner.py +208 -28
  122. sglang/srt/model_loader/loader.py +3 -3
  123. sglang/srt/model_loader/weight_utils.py +36 -14
  124. sglang/srt/models/baichuan.py +31 -6
  125. sglang/srt/models/chatglm.py +39 -7
  126. sglang/srt/models/commandr.py +29 -5
  127. sglang/srt/models/dbrx.py +31 -5
  128. sglang/srt/models/deepseek.py +43 -6
  129. sglang/srt/models/deepseek_nextn.py +32 -19
  130. sglang/srt/models/deepseek_v2.py +265 -32
  131. sglang/srt/models/exaone.py +19 -9
  132. sglang/srt/models/gemma.py +22 -8
  133. sglang/srt/models/gemma2.py +25 -12
  134. sglang/srt/models/gemma2_reward.py +5 -1
  135. sglang/srt/models/gpt2.py +28 -13
  136. sglang/srt/models/gpt_bigcode.py +27 -5
  137. sglang/srt/models/granite.py +21 -9
  138. sglang/srt/models/grok.py +21 -4
  139. sglang/srt/models/internlm2.py +36 -6
  140. sglang/srt/models/internlm2_reward.py +5 -1
  141. sglang/srt/models/llama.py +26 -9
  142. sglang/srt/models/llama_classification.py +5 -1
  143. sglang/srt/models/llama_eagle.py +17 -4
  144. sglang/srt/models/llama_embedding.py +5 -1
  145. sglang/srt/models/llama_reward.py +7 -2
  146. sglang/srt/models/llava.py +19 -3
  147. sglang/srt/models/llavavid.py +10 -1
  148. sglang/srt/models/minicpm.py +26 -2
  149. sglang/srt/models/minicpm3.py +39 -3
  150. sglang/srt/models/minicpmv.py +45 -14
  151. sglang/srt/models/mixtral.py +20 -9
  152. sglang/srt/models/mixtral_quant.py +50 -8
  153. sglang/srt/models/mllama.py +57 -11
  154. sglang/srt/models/olmo.py +34 -6
  155. sglang/srt/models/olmo2.py +34 -13
  156. sglang/srt/models/olmoe.py +26 -4
  157. sglang/srt/models/phi3_small.py +29 -10
  158. sglang/srt/models/qwen.py +26 -3
  159. sglang/srt/models/qwen2.py +26 -4
  160. sglang/srt/models/qwen2_5_vl.py +46 -8
  161. sglang/srt/models/qwen2_eagle.py +17 -5
  162. sglang/srt/models/qwen2_moe.py +44 -6
  163. sglang/srt/models/qwen2_rm.py +78 -0
  164. sglang/srt/models/qwen2_vl.py +39 -8
  165. sglang/srt/models/stablelm.py +32 -5
  166. sglang/srt/models/torch_native_llama.py +5 -2
  167. sglang/srt/models/xverse.py +21 -9
  168. sglang/srt/models/xverse_moe.py +45 -7
  169. sglang/srt/models/yivl.py +2 -1
  170. sglang/srt/openai_api/adapter.py +109 -24
  171. sglang/srt/openai_api/protocol.py +17 -1
  172. sglang/srt/reasoning_parser.py +154 -0
  173. sglang/srt/sampling/penaltylib/__init__.py +4 -6
  174. sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
  175. sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
  176. sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
  177. sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
  178. sglang/srt/sampling/sampling_batch_info.py +79 -157
  179. sglang/srt/sampling/sampling_params.py +16 -13
  180. sglang/srt/server_args.py +136 -52
  181. sglang/srt/speculative/build_eagle_tree.py +2 -8
  182. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
  183. sglang/srt/speculative/eagle_utils.py +92 -58
  184. sglang/srt/speculative/eagle_worker.py +186 -94
  185. sglang/srt/speculative/spec_info.py +1 -13
  186. sglang/srt/utils.py +43 -17
  187. sglang/srt/warmup.py +47 -0
  188. sglang/test/few_shot_gsm8k.py +4 -1
  189. sglang/test/runners.py +389 -126
  190. sglang/test/send_one.py +88 -0
  191. sglang/test/test_block_fp8_ep.py +361 -0
  192. sglang/test/test_programs.py +1 -1
  193. sglang/test/test_utils.py +138 -84
  194. sglang/utils.py +50 -60
  195. sglang/version.py +1 -1
  196. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
  197. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
  198. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
  199. sglang/bench_latency.py +0 -1
  200. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
  201. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
  202. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
  203. sglang/test/srt/sampling/penaltylib/utils.py +0 -344
  204. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
  205. {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/distributed/device_communicators/custom_all_reduce.py
@@ -18,24 +18,40 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
     gpu_p2p_access_check,
 )
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
-from sglang.srt.utils import cuda_device_count_stateless, is_cuda
+from sglang.srt.utils import cuda_device_count_stateless, is_cuda, is_hip

 logger = logging.getLogger(__name__)

+is_hip_ = is_hip()
+
 if is_cuda():
     try:
         import pynvml
     except ImportError as e:
         logger.warning("Failed to import pynvml with %r", e)

+if is_hip_:
+    try:
+        from amdsmi import (
+            AmdSmiException,
+            amdsmi_get_processor_handles,
+            amdsmi_init,
+            amdsmi_shut_down,
+            amdsmi_topo_get_link_type,
+        )
+    except ImportError as e:
+        logger.warning("Failed to import amdsmi with %r", e)
+
 try:
-    if ops.use_vllm_custom_allreduce:
+    if ops.use_vllm_custom_allreduce and not is_hip_:
+        # Use vLLM custom allreduce
         ops.meta_size()
     else:
+        # Use custom allreduce from sgl kernel (ROCM and TRT-LLM)
         import sgl_kernel
     custom_ar = True
 except Exception:
-    # For AMD GPUs and CPUs
+    # For CPUs
     custom_ar = False

 logger = logging.getLogger(__name__)
@@ -47,37 +63,62 @@ _R = TypeVar("_R")
 def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
     @wraps(fn)
     def wrapper(*args: _P.args, **kwargs: _P.kwargs) -> _R:
-        pynvml.nvmlInit()
-        try:
-            return fn(*args, **kwargs)
-        finally:
-            pynvml.nvmlShutdown()
+        if is_hip_:
+            try:
+                amdsmi_init()
+                return fn(*args, **kwargs)
+            finally:
+                amdsmi_shut_down()
+        else:
+            pynvml.nvmlInit()
+            try:
+                return fn(*args, **kwargs)
+            finally:
+                pynvml.nvmlShutdown()

     return wrapper


 @with_nvml_context
-def is_full_nvlink(physical_device_ids: List[int]) -> bool:
-    """
-    query if the set of gpus are fully connected by nvlink (1 hop)
-    """
-    handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
-    for i, handle in enumerate(handles):
-        for j, peer_handle in enumerate(handles):
-            if i < j:
-                try:
-                    p2p_status = pynvml.nvmlDeviceGetP2PStatus(
-                        handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
-                    )
-                    if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+def is_full_nvlink(physical_device_ids: List[int], world_size: int) -> bool:
+    if is_hip_:
+        """
+        query if the set of gpus are fully connected by xgmi (1 hop)
+        """
+        handles = [amdsmi_get_processor_handles()[i] for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        link_type = amdsmi_topo_get_link_type(handle, peer_handle)
+                        # type is 2 for XGMI
+                        if link_type["hops"] != 1 or link_type["type"] != 2:
+                            return False
+                    except AmdSmiException as error:
+                        logger.error("AMD 1 hop XGMI detection failed.", exc_info=error)
                         return False
-                except pynvml.NVMLError:
-                    logger.exception(
-                        "NVLink detection failed. This is normal if your"
-                        " machine has no NVLink equipped."
-                    )
-                    return False
-    return True
+        return True
+    else:
+        """
+        query if the set of gpus are fully connected by nvlink (1 hop)
+        """
+        handles = [pynvml.nvmlDeviceGetHandleByIndex(i) for i in physical_device_ids]
+        for i, handle in enumerate(handles):
+            for j, peer_handle in enumerate(handles):
+                if i < j:
+                    try:
+                        p2p_status = pynvml.nvmlDeviceGetP2PStatus(
+                            handle, peer_handle, pynvml.NVML_P2P_CAPS_INDEX_NVLINK
+                        )
+                        if p2p_status != pynvml.NVML_P2P_STATUS_OK:
+                            return False
+                    except pynvml.NVMLError:
+                        logger.exception(
+                            "NVLink detection failed. This is normal if your"
+                            " machine has no NVLink equipped."
+                        )
+                        return False
+        return True


 def _can_p2p(rank: int, world_size: int) -> bool:
@@ -102,15 +143,18 @@ def is_weak_contiguous(inp: torch.Tensor):


 class CustomAllreduce:
-
     _SUPPORTED_WORLD_SIZES = [2, 4, 6, 8]
+    _MAX_CAR_SIZE = 8192 * 1024
+    if is_hip_:
+        # crossover is at 16MB buffer size for ROCm
+        _MAX_CAR_SIZE = 2 * 8192 * 1024

     # max_size: max supported allreduce size
     def __init__(
         self,
         group: ProcessGroup,
         device: Union[int, str, torch.device],
-        max_size=8192 * 1024,
+        max_size=_MAX_CAR_SIZE,
     ) -> None:
         """
         Args:
@@ -185,12 +229,9 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        if is_cuda():
-            assert is_cuda()
+        if is_cuda() or is_hip_:
+            full_nvlink = is_full_nvlink(physical_device_ids, world_size)

-            full_nvlink = is_full_nvlink(physical_device_ids)
-        else:
-            full_nvlink = False
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
@@ -201,7 +242,8 @@ class CustomAllreduce:
         # test P2P capability, this checks software/cudaruntime support
         # this is expensive to compute at the first time
         # then we cache the result
-        if not _can_p2p(rank, world_size):
+        # On AMD GPU, p2p is always enabled between XGMI connected GPUs
+        if not is_hip_ and not _can_p2p(rank, world_size):
             logger.warning(
                 "Custom allreduce is disabled because your platform lacks "
                 "GPU P2P capability or P2P test failed. To silence this "
@@ -214,7 +256,7 @@ class CustomAllreduce:
         self.world_size = world_size
         self.full_nvlink = full_nvlink

-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             # Buffers memory are owned by this Python class and passed to C++.
             # Meta data composes of two parts: meta data for synchronization and a
             # temporary buffer for storing intermediate allreduce results.
@@ -237,35 +279,56 @@ class CustomAllreduce:
             )
             ops.register_buffer(self._ptr, self.buffer_ptrs)
         else:
-            # From TensorRT-LLM getMaxRequiredWorkspaceSize
-            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+            if is_hip_:
+                # meta data buffers need to be "uncached" for signal on MI200
+                self.meta = ops.allocate_meta_buffer(ops.meta_size() + max_size)
+                self.buffer = torch.empty(
+                    max_size, dtype=torch.uint8, device=self.device
+                )
+                handle = ops.get_meta_buffer_ipc_handle(self.meta)
+                shard_data = (
+                    bytes(handle),  # ipc handle to base ptr
+                    0,  # offset of base ptr
+                )
+                handles, offsets = self._gather_ipc_meta(shard_data)
+                self.rank_data = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self._ptr = ops.init_custom_ar(
+                    self.meta, self.rank_data, handles, offsets, rank, self.full_nvlink
+                )
+                self.register_buffer(self.buffer)
+                self.MSCCL = os.getenv("RCCL_MSCCL_ENABLE", "1") == "1"
+            else:
+                # From TensorRT-LLM getMaxRequiredWorkspaceSize
+                self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]

-            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
-            self.barrier_max_size = 8 * (36 + 2) * 8
+                # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+                self.barrier_max_size = 8 * (36 + 2) * 8

-            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
-                max_size, group=group
-            )
-            self.rank_data_base = torch.empty(
-                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-            )
-            self.barrier_in_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
-            self.barrier_out_ptrs = self.create_shared_buffer(
-                self.barrier_max_size, group=group
-            )
+                self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+                self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                    max_size, group=group
+                )
+                self.rank_data_base = torch.empty(
+                    8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+                )
+                self.barrier_in_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )
+                self.barrier_out_ptrs = self.create_shared_buffer(
+                    self.barrier_max_size, group=group
+                )

-            self._ptr = ops.init_custom_ar(
-                rank,
-                world_size,
-                self.rank_data_base,
-                self.buffer_ptrs,
-                self.tmp_result_buffer_ptrs,
-                self.barrier_in_ptrs,
-                self.barrier_out_ptrs,
-            )
+                self._ptr = ops.init_custom_ar(
+                    rank,
+                    world_size,
+                    self.rank_data_base,
+                    self.buffer_ptrs,
+                    self.tmp_result_buffer_ptrs,
+                    self.barrier_in_ptrs,
+                    self.barrier_out_ptrs,
+                )
         self.disabled = False

     @staticmethod
@@ -316,23 +379,69 @@ class CustomAllreduce:
             if not self.disabled:
                 self.register_graph_buffers()

-    def register_graph_buffers(self):
-        handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
-        logger.info("Registering %d cuda graph addresses", len(offset))
-        # We cannot directly use `dist.all_gather_object` here
-        # because it is incompatible with `gloo` backend under inference mode.
-        # see https://github.com/pytorch/pytorch/issues/126032 for details.
-        all_data = [[None, None] for _ in range(dist.get_world_size(group=self.group))]
-        all_data[self.rank] = [handle, offset]
-        ranks = sorted(dist.get_process_group_ranks(group=self.group))
+    def _get_ipc_meta(self, inp: torch.Tensor):
+        # _share_cuda_() doesn't accept meta buffer not allocated from
+        # PyTorch cache allocator, use direct HIP call to get IPC handle
+        handle = ops.get_meta_buffer_ipc_handle(inp)
+        shard_data = (
+            bytes(handle),  # ipc handle to base ptr
+            0,  # offset of base ptr
+        )
+        return self._gather_ipc_meta(shard_data)
+
+    def _gather_ipc_meta(self, shard_data):
+        # Note: don't use `[[None]] * self.world_size` here
+        # because it will create a list of the same reference
+        all_data: List[Optional[Any]] = [[None] for i in range(self.world_size)]
+        all_data[self.rank][0] = shard_data
+
+        ranks = dist.get_process_group_ranks(group=self.group)
+        ranks.sort()
         for i, rank in enumerate(ranks):
             dist.broadcast_object_list(
                 all_data[i], src=rank, group=self.group, device="cpu"
             )
-        # Unpack list of tuples to tuple of lists.
-        handles = [d[0] for d in all_data]  # type: ignore
-        offsets = [d[1] for d in all_data]  # type: ignore
-        ops.register_graph_buffers(self._ptr, handles, offsets)
+
+        # we cannot directly use `dist.all_gather_object` here
+        # because it is incompatible with `gloo` backend under inference mode.
+        # see https://github.com/pytorch/pytorch/issues/126032 for details.
+
+        handles = []
+        offsets = []
+        for i in range(len(all_data)):
+            handles.append(all_data[i][0][0])  # type: ignore
+            offsets.append(all_data[i][0][1])  # type: ignore
+        return handles, offsets
+
+    def register_buffer(self, inp: torch.Tensor):
+        handles, offsets = self._get_ipc_meta(inp)
+        ops.register_buffer(self._ptr, inp, handles, offsets)
+
+    def register_graph_buffers(self):
+        if is_hip_:
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            handles, offsets = self._gather_ipc_meta((bytes(handle), offset))
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            ops.register_graph_buffers(self._ptr, handles, offsets)
+        else:
+            handle, offset = ops.get_graph_buffer_ipc_meta(self._ptr)
+            logger.info("Registering %d cuda graph addresses", len(offset))
+            # We cannot directly use `dist.all_gather_object` here
+            # because it is incompatible with `gloo` backend under inference mode.
+            # see https://github.com/pytorch/pytorch/issues/126032 for details.
+            all_data = [
+                [None, None] for _ in range(dist.get_world_size(group=self.group))
+            ]
+            all_data[self.rank] = [handle, offset]
+            ranks = sorted(dist.get_process_group_ranks(group=self.group))
+            for i, rank in enumerate(ranks):
+                dist.broadcast_object_list(
+                    all_data[i], src=rank, group=self.group, device="cpu"
+                )
+            # Unpack list of tuples to tuple of lists.
+            handles = [d[0] for d in all_data]  # type: ignore
+            offsets = [d[1] for d in all_data]  # type: ignore
+            ops.register_graph_buffers(self._ptr, handles, offsets)

     def should_custom_ar(self, inp: torch.Tensor):
         if self.disabled:
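Note on the `_gather_ipc_meta` helper added above: it collects one picklable object per rank with repeated `dist.broadcast_object_list` calls because `dist.all_gather_object` is incompatible with the `gloo` backend under inference mode (pytorch/pytorch#126032). A generic, self-contained sketch of that pattern; the helper name below is hypothetical and not part of the diff:

import torch.distributed as dist


def gather_object_per_rank(obj, group):
    # Hypothetical helper (not in the diff): gather one picklable object from
    # every rank in `group` without dist.all_gather_object, which breaks with
    # the gloo backend under inference mode (pytorch/pytorch#126032).
    world_size = dist.get_world_size(group=group)
    rank = dist.get_rank(group=group)
    # One single-slot list per rank; [[None]] * world_size would alias the
    # same inner list across all ranks.
    all_data = [[None] for _ in range(world_size)]
    all_data[rank][0] = obj
    for i, src in enumerate(sorted(dist.get_process_group_ranks(group=group))):
        # Each rank broadcasts its slot in turn; src is a global rank.
        dist.broadcast_object_list(all_data[i], src=src, group=group, device="cpu")
    return [slot[0] for slot in all_data]

The diff applies this same pattern to IPC handles and offsets in `_gather_ipc_meta` and in the HIP branch of `register_graph_buffers`.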
@@ -345,11 +454,22 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if ops.use_vllm_custom_allreduce:
+        if ops.use_vllm_custom_allreduce and not is_hip_:
             if self.world_size == 2 or self.full_nvlink:
                 return inp_size < self.max_size
             return False

+        if is_hip_:
+            if self.full_nvlink:
+                if self.world_size == 8:
+                    if self.MSCCL:
+                        return False
+                    else:
+                        return inp_size < self.max_size
+                else:
+                    return inp_size < self.max_size
+            return False
+
         if self.world_size == 2:
             return (
                 inp_size < self.max_size
@@ -364,6 +484,21 @@ class CustomAllreduce:

         return False

+    # all reduce, assuming inp tensor is IPC registered with register_buffer,
+    # or, in the context of cuda graphs, register_graph_buffers
+    def all_reduce_reg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_reg(self._ptr, inp, out)
+        return out
+
+    # all reduce, assuming inp tensor is NOT IPC registered
+    def all_reduce_unreg(self, inp: torch.Tensor, out: torch.Tensor = None):
+        if out is None:
+            out = torch.empty_like(inp)
+        ops.all_reduce_unreg(self._ptr, inp, self.buffer, out)
+        return out
+
     def all_reduce(
         self,
         inp: torch.Tensor,
@@ -397,13 +532,23 @@ class CustomAllreduce:
             return None
         if self._IS_CAPTURING:
             if torch.cuda.is_current_stream_capturing():
-                return self.all_reduce(input, registered=True)
+                if is_hip_:
+                    return self.all_reduce_reg(input)
+                else:
+                    return self.all_reduce(input, registered=True)
             else:
                 # If warm up, mimic the allocation pattern since custom
                 # allreduce is out-of-place.
                 return torch.empty_like(input)
         else:
-            return self.all_reduce(input, registered=False)
+            if is_hip_:
+                # note: outside of cuda graph context,
+                # custom allreduce incurs a cost of cudaMemcpy, which should
+                # be small(<=1% of overall latency) compared to the performance
+                # gains of using custom kernels
+                return self.all_reduce_unreg(input)
+            else:
+                return self.all_reduce(input, registered=False)

     def close(self):
         if not self.disabled and self._ptr:
@@ -411,7 +556,7 @@ class CustomAllreduce:
             if ops.use_vllm_custom_allreduce:
                 self.free_shared_buffer(self.meta_ptrs)
                 self.free_shared_buffer(self.buffer_ptrs)
-            else:
+            elif is_cuda():
                 self.free_shared_buffer(self.buffer_ptrs)
                 self.free_shared_buffer(self.tmp_result_buffer_ptrs)
                 self.free_shared_buffer(self.barrier_in_ptrs)
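Taken together, the custom_all_reduce.py changes give the ROCm path two entry points: `all_reduce_reg` for tensors that are IPC-registered (the CUDA-graph capture path) and `all_reduce_unreg`, which stages the input through the shared `self.buffer`. A condensed sketch of that dispatch, where `comm` stands in for an already-initialized `CustomAllreduce` instance (an assumption for illustration, not part of the diff):

import torch


def rocm_custom_allreduce(comm, inp: torch.Tensor):
    # Sketch only: mirrors the is_hip_ branches added above.
    if not comm.should_custom_ar(inp):
        return None  # caller falls back to the regular NCCL/RCCL all-reduce
    if torch.cuda.is_current_stream_capturing():
        # During CUDA graph capture the buffers are IPC-registered
        # (register_graph_buffers), so the registered kernel can be used.
        return comm.all_reduce_reg(inp)
    # Outside graph capture, stage through comm.buffer; the extra copy is
    # expected to be a small fraction of the overall latency.
    return comm.all_reduce_unreg(inp)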
sglang/srt/distributed/parallel_state.py
@@ -30,6 +30,7 @@ import weakref
 from collections import namedtuple
 from contextlib import contextmanager, nullcontext
 from dataclasses import dataclass
+from datetime import timedelta
 from multiprocessing import shared_memory
 from typing import Any, Callable, Dict, List, Optional, Tuple, Union
 from unittest.mock import patch
@@ -138,6 +139,27 @@ if supports_custom_op():
         fake_impl=outplace_all_reduce_fake,
     )

+    def reg_all_gather_into_tensor(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        assert group_name in _groups, f"Group {group_name} is not found."
+        group = _groups[group_name]()
+        if group is None:
+            raise ValueError(f"Group {group_name} is destroyed.")
+        group._all_gather_into_tensor(output, input)
+
+    def reg_all_gather_into_tensor_fake(
+        output: torch.Tensor, input: torch.Tensor, group_name: str
+    ) -> None:
+        pass
+
+    direct_register_custom_op(
+        op_name="reg_all_gather_into_tensor",
+        op_func=reg_all_gather_into_tensor,
+        mutates_args=["output"],
+        fake_impl=reg_all_gather_into_tensor_fake,
+    )
+

 class GroupCoordinator:
     """
@@ -413,6 +435,23 @@ class GroupCoordinator:
         else:
             torch.distributed.all_reduce(input_, group=self.device_group)

+    def _all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        pynccl_comm = self.pynccl_comm
+        if pynccl_comm is not None and not pynccl_comm.disabled:
+            pynccl_comm.all_gather(output, input)
+        else:
+            torch.distributed.all_gather_into_tensor(
+                output, input, group=self.device_group
+            )
+
+    def all_gather_into_tensor(self, output: torch.Tensor, input: torch.Tensor):
+        if not supports_custom_op():
+            self._all_gather_into_tensor(output, input)
+        else:
+            torch.ops.sglang.reg_all_gather_into_tensor(
+                output, input, group_name=self.unique_name
+            )
+
     def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
         world_size = self.world_size
         # Bypass the function if we are using only 1 GPU.
@@ -440,9 +479,7 @@ class GroupCoordinator:
             output_size, dtype=input_.dtype, device=input_.device
         )
         # All-gather.
-        torch.distributed.all_gather_into_tensor(
-            output_tensor, input_, group=self.device_group
-        )
+        self.all_gather_into_tensor(output_tensor, input_)
         # Reshape
         output_tensor = output_tensor.reshape((world_size,) + input_size)
         output_tensor = output_tensor.movedim(0, dim)
@@ -960,6 +997,7 @@ def init_distributed_environment(
     distributed_init_method: str = "env://",
     local_rank: int = -1,
     backend: str = "nccl",
+    timeout: Optional[int] = None,
 ):
     logger.debug(
         "world_size=%d rank=%d local_rank=%d " "distributed_init_method=%s backend=%s",
@@ -974,13 +1012,20 @@
         "distributed_init_method must be provided when initializing "
         "distributed environment"
     )
+    if timeout is not None:
+        assert isinstance(timeout, (int)), "timeout must be a number"
+        assert timeout > 0, "timeout must be positive"
+        timeout = timedelta(seconds=timeout)
+
     # this backend is used for WORLD
     torch.distributed.init_process_group(
         backend=backend,
        init_method=distributed_init_method,
         world_size=world_size,
         rank=rank,
+        timeout=timeout,
     )
+
     # set the local rank
     # local_rank is not available in torch ProcessGroup,
     # see https://github.com/pytorch/pytorch/issues/122816
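The new `timeout` argument to `init_distributed_environment` is an integer number of seconds that is validated and converted to a `datetime.timedelta` before being forwarded to `torch.distributed.init_process_group`. A minimal sketch of the same handling outside sglang; the wrapper name and the `timeout_s` parameter are illustrative assumptions:

from datetime import timedelta
from typing import Optional

import torch.distributed as dist


def init_process_group_with_timeout(
    rank: int,
    world_size: int,
    init_method: str,
    timeout_s: Optional[int] = None,
) -> None:
    # Hypothetical wrapper mirroring the timeout handling added above.
    timeout = None
    if timeout_s is not None:
        assert timeout_s > 0, "timeout must be positive"
        timeout = timedelta(seconds=timeout_s)
    dist.init_process_group(
        backend="nccl",
        init_method=init_method,
        world_size=world_size,
        rank=rank,
        # With recent PyTorch, None falls back to the default process-group timeout.
        timeout=timeout,
    )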