sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (141)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +41 -27
  4. sglang/bench_one_batch.py +60 -4
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +83 -71
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +46 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/load_config.py +1 -0
  13. sglang/srt/configs/model_config.py +1 -0
  14. sglang/srt/constrained/base_grammar_backend.py +21 -0
  15. sglang/srt/constrained/xgrammar_backend.py +8 -4
  16. sglang/srt/conversation.py +14 -1
  17. sglang/srt/distributed/__init__.py +3 -3
  18. sglang/srt/distributed/communication_op.py +2 -1
  19. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  20. sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
  21. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  22. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  23. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  24. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  25. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  26. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  27. sglang/srt/distributed/parallel_state.py +1 -1
  28. sglang/srt/distributed/utils.py +2 -1
  29. sglang/srt/entrypoints/engine.py +452 -0
  30. sglang/srt/entrypoints/http_server.py +603 -0
  31. sglang/srt/function_call_parser.py +494 -0
  32. sglang/srt/layers/activation.py +8 -8
  33. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  34. sglang/srt/layers/attention/triton_backend.py +4 -6
  35. sglang/srt/layers/attention/vision.py +204 -0
  36. sglang/srt/layers/dp_attention.py +71 -0
  37. sglang/srt/layers/layernorm.py +5 -5
  38. sglang/srt/layers/linear.py +65 -14
  39. sglang/srt/layers/logits_processor.py +49 -64
  40. sglang/srt/layers/moe/ep_moe/layer.py +24 -16
  41. sglang/srt/layers/moe/fused_moe_native.py +84 -1
  42. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
  45. sglang/srt/layers/parameter.py +18 -8
  46. sglang/srt/layers/quantization/__init__.py +20 -23
  47. sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  48. sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  49. sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  50. sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  51. sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  52. sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  53. sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  54. sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  55. sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
  56. sglang/srt/layers/quantization/fp8.py +10 -4
  57. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  58. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  59. sglang/srt/layers/radix_attention.py +2 -2
  60. sglang/srt/layers/rotary_embedding.py +1184 -31
  61. sglang/srt/layers/sampler.py +64 -6
  62. sglang/srt/layers/torchao_utils.py +12 -6
  63. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  64. sglang/srt/lora/lora.py +1 -9
  65. sglang/srt/managers/configure_logging.py +3 -0
  66. sglang/srt/managers/data_parallel_controller.py +79 -72
  67. sglang/srt/managers/detokenizer_manager.py +24 -6
  68. sglang/srt/managers/image_processor.py +158 -2
  69. sglang/srt/managers/io_struct.py +57 -3
  70. sglang/srt/managers/schedule_batch.py +78 -45
  71. sglang/srt/managers/schedule_policy.py +26 -12
  72. sglang/srt/managers/scheduler.py +326 -201
  73. sglang/srt/managers/session_controller.py +1 -0
  74. sglang/srt/managers/tokenizer_manager.py +210 -121
  75. sglang/srt/managers/tp_worker.py +6 -4
  76. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  77. sglang/srt/managers/utils.py +44 -0
  78. sglang/srt/mem_cache/memory_pool.py +10 -32
  79. sglang/srt/metrics/collector.py +15 -6
  80. sglang/srt/model_executor/cuda_graph_runner.py +26 -30
  81. sglang/srt/model_executor/forward_batch_info.py +5 -7
  82. sglang/srt/model_executor/model_runner.py +44 -19
  83. sglang/srt/model_loader/loader.py +83 -6
  84. sglang/srt/model_loader/weight_utils.py +145 -6
  85. sglang/srt/models/baichuan.py +6 -6
  86. sglang/srt/models/chatglm.py +2 -2
  87. sglang/srt/models/commandr.py +17 -5
  88. sglang/srt/models/dbrx.py +13 -5
  89. sglang/srt/models/deepseek.py +3 -3
  90. sglang/srt/models/deepseek_v2.py +11 -11
  91. sglang/srt/models/exaone.py +2 -2
  92. sglang/srt/models/gemma.py +2 -2
  93. sglang/srt/models/gemma2.py +15 -25
  94. sglang/srt/models/gpt2.py +3 -5
  95. sglang/srt/models/gpt_bigcode.py +1 -1
  96. sglang/srt/models/granite.py +2 -2
  97. sglang/srt/models/grok.py +4 -3
  98. sglang/srt/models/internlm2.py +2 -2
  99. sglang/srt/models/llama.py +7 -5
  100. sglang/srt/models/minicpm.py +2 -2
  101. sglang/srt/models/minicpm3.py +9 -9
  102. sglang/srt/models/minicpmv.py +1238 -0
  103. sglang/srt/models/mixtral.py +3 -3
  104. sglang/srt/models/mixtral_quant.py +3 -3
  105. sglang/srt/models/mllama.py +2 -2
  106. sglang/srt/models/olmo.py +3 -3
  107. sglang/srt/models/olmo2.py +4 -4
  108. sglang/srt/models/olmoe.py +7 -13
  109. sglang/srt/models/phi3_small.py +2 -2
  110. sglang/srt/models/qwen.py +2 -2
  111. sglang/srt/models/qwen2.py +41 -4
  112. sglang/srt/models/qwen2_moe.py +3 -3
  113. sglang/srt/models/qwen2_vl.py +22 -122
  114. sglang/srt/models/stablelm.py +2 -2
  115. sglang/srt/models/torch_native_llama.py +20 -7
  116. sglang/srt/models/xverse.py +6 -6
  117. sglang/srt/models/xverse_moe.py +6 -6
  118. sglang/srt/openai_api/adapter.py +139 -37
  119. sglang/srt/openai_api/protocol.py +7 -4
  120. sglang/srt/sampling/custom_logit_processor.py +38 -0
  121. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
  122. sglang/srt/sampling/sampling_batch_info.py +143 -18
  123. sglang/srt/sampling/sampling_params.py +3 -1
  124. sglang/srt/server.py +4 -1090
  125. sglang/srt/server_args.py +77 -15
  126. sglang/srt/speculative/eagle_utils.py +37 -15
  127. sglang/srt/speculative/eagle_worker.py +11 -13
  128. sglang/srt/utils.py +164 -129
  129. sglang/test/runners.py +8 -13
  130. sglang/test/test_programs.py +2 -1
  131. sglang/test/test_utils.py +83 -22
  132. sglang/utils.py +12 -2
  133. sglang/version.py +1 -1
  134. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
  135. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
  136. sglang/launch_server_llavavid.py +0 -25
  137. sglang/srt/constrained/__init__.py +0 -16
  138. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  139. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
  140. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
  141. {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/_custom_ops.py CHANGED
@@ -1,8 +1,9 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
 import logging
+import os
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
@@ -11,12 +12,19 @@ import torch.library
 from sglang.srt.utils import is_hpu
 
 logger = logging.getLogger(__name__)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 
 if not is_hpu():
-    try:
-        import custom_ar
-    except ImportError as e:
-        logger.warning("Failed to import from custom_ar with %r", e)
+    if use_vllm_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
+    else:
+        try:
+            import sgl_kernel
+        except ImportError as e:
+            logger.warning("Failed to import from custom_ar with %r", e)
 
 
 def hint_on_error(fn):
@@ -48,48 +56,78 @@ def hint_on_error(fn):
     return wrapper
 
 
-# custom ar
-def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
-) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
-    )
-
-
-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
-
-
-def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
+if use_vllm_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: List[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, full_nvlink
+        )
 
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+else:
+    # custom ar
+    def init_custom_ar(
+        rank_id: int,
+        world_size: int,
+        rank_data_base: torch.Tensor,
+        buffers: List[int],
+        tmp_result_buffers: List[int],
+        barrier_in: List[int],
+        barrier_out: List[int],
+    ) -> int:
+        return sgl_kernel.ops.init_custom_reduce(
+            rank_id,
+            world_size,
+            rank_data_base,
+            buffers,
+            tmp_result_buffers,
+            barrier_in,
+            barrier_out,
+        )
 
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.ops.custom_reduce(fa, inp, out)
 
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    def dispose(fa: int) -> None:
+        sgl_kernel.ops.custom_dispose(fa)
 
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
 
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
 
 
 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
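For context, a minimal sketch (not shipped in the wheel) of the dispatch pattern the hunk above introduces: an environment toggle read at import time decides whether the allreduce ops bind to vLLM's compiled extension or to sgl_kernel. The string-to-bool parsing below is an illustrative assumption; the shipped code passes the raw value of os.environ.get straight through.

import os

# Hypothetical helper mirroring the USE_VLLM_CUSTOM_ALLREDUCE toggle above.
def _use_vllm_custom_allreduce() -> bool:
    value = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", "1")
    return str(value).lower() not in ("0", "false", "no")

if _use_vllm_custom_allreduce():
    provider = "vllm._C (torch.ops._C_custom_ar)"
else:
    provider = "sgl_kernel.ops custom reduce"
print(f"custom allreduce provider: {provider}")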
sglang/srt/configs/device_config.py CHANGED
@@ -10,7 +10,7 @@ class DeviceConfig:
     device: Optional[torch.device]
 
     def __init__(self, device: str = "cuda") -> None:
-        if device in ["cuda", "xpu", "hpu"]:
+        if device in ["cuda", "xpu", "hpu", "cpu"]:
            self.device_type = device
         else:
            raise RuntimeError(f"Not supported device type: {device}")
sglang/srt/configs/load_config.py CHANGED
@@ -20,6 +20,7 @@ class LoadFormat(str, enum.Enum):
     GGUF = "gguf"
     BITSANDBYTES = "bitsandbytes"
     MISTRAL = "mistral"
+    LAYERED = "layered"
 
 
 @dataclass
sglang/srt/configs/model_config.py CHANGED
@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
         or "LlavaVidForCausalLM" in model_architectures
         or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
+        or "MiniCPMV" in model_architectures
     ):
         return True
     else:
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -18,6 +18,8 @@ from dataclasses import dataclass
 from threading import Event, Lock
 from typing import Any, Optional, Tuple
 
+from sglang.srt.server_args import ServerArgs
+
 
 @dataclass
 class CacheEntry:
@@ -69,3 +71,22 @@ class BaseGrammarBackend:
     def reset(self):
         with self.cache_lock:
             self.cache.clear()
+
+
+def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+    if server_args.grammar_backend == "outlines":
+        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
+
+        grammar_backend = OutlinesGrammarBackend(
+            tokenizer,
+            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            allow_jump_forward=not server_args.disable_jump_forward,
+        )
+    elif server_args.grammar_backend == "xgrammar":
+        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
+
+        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+    else:
+        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+
+    return grammar_backend
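A hedged usage sketch of the new create_grammar_backend factory, not taken from the package itself; it assumes sglang and transformers are installed, that ServerArgs can be constructed with the fields shown, and uses a placeholder model name and len(tokenizer) as a stand-in for the model's vocab size.

from transformers import AutoTokenizer

from sglang.srt.constrained.base_grammar_backend import create_grammar_backend
from sglang.srt.server_args import ServerArgs

# "xgrammar" selects XGrammarGrammarBackend, "outlines" selects
# OutlinesGrammarBackend, and any other value raises ValueError.
server_args = ServerArgs(model_path="Qwen/Qwen2-0.5B-Instruct", grammar_backend="xgrammar")
tokenizer = AutoTokenizer.from_pretrained(server_args.model_path)
backend = create_grammar_backend(server_args, tokenizer, vocab_size=len(tokenizer))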
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -19,6 +19,7 @@ from typing import List, Tuple
 import torch
 from xgrammar import (
     CompiledGrammar,
+    Grammar,
     GrammarCompiler,
     GrammarMatcher,
     TokenizerInfo,
@@ -133,10 +134,13 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
         elif key_type == "regex":
-            logger.warning(
-                "regex hasn't been supported by xgrammar yet. This is skipped."
-            )
-            return None
+            try:
+                ctx = self.grammar_compiler.compile_grammar(
+                    Grammar.from_regex(key_string)
+                )
+            except RuntimeError as e:
+                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+                return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
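The new regex branch above leans on xgrammar's Grammar.from_regex; below is a small sketch of the same flow outside the backend. It is illustrative only and assumes xgrammar and transformers are installed; the model name and regex pattern are placeholders.

from transformers import AutoTokenizer
from xgrammar import Grammar, GrammarCompiler, TokenizerInfo

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
tokenizer_info = TokenizerInfo.from_huggingface(tokenizer, vocab_size=len(tokenizer))
compiler = GrammarCompiler(tokenizer_info)

try:
    # Compile a regex into a token-level grammar, as the new "regex" branch does.
    compiled = compiler.compile_grammar(Grammar.from_regex(r"[a-z]+@[a-z]+\.com"))
except RuntimeError:
    compiled = None  # the backend logs a warning and skips invalid regexes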
sglang/srt/conversation.py CHANGED
@@ -452,7 +452,6 @@ def generate_chat_conv(
 
     # Add a blank message for the assistant.
     conv.append_message(conv.roles[1], None)
-
     return conv
 
 
@@ -555,3 +554,17 @@ register_conv_template(
         image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )
+
+# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
+register_conv_template(
+    Conversation(
+        name="minicpmv",
+        system_message="You are a helpful assistant",
+        system_template="<|im_start|>system\n{system_message}.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+    )
+)
sglang/srt/distributed/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
sglang/srt/distributed/communication_op.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
+
 from typing import Any, Dict, Optional, Union
 
 import torch
sglang/srt/distributed/device_communicators/cuda_wrapper.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
+
 """This file is a pure Python wrapper for the cudart library.
 It avoids the need to compile a separate shared library, and is
 convenient for use when we just need to call a few functions.
sglang/srt/distributed/device_communicators/custom_all_reduce.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
+
 import ctypes
 import logging
 import os
@@ -6,7 +7,6 @@ from contextlib import contextmanager
 from functools import wraps
 from typing import Callable, List, Optional, TypeVar, Union
 
-import pynvml
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -20,8 +20,19 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda
 
+logger = logging.getLogger(__name__)
+
+if is_cuda():
+    try:
+        import pynvml
+    except ImportError as e:
+        logger.warning("Failed to import pynvml with %r", e)
+
 try:
-    ops.meta_size()
+    if ops.use_vllm_custom_allreduce:
+        ops.meta_size()
+    else:
+        import sgl_kernel
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -29,7 +40,6 @@ except Exception:
 
 logger = logging.getLogger(__name__)
 
-
 _P = ParamSpec("_P")
 _R = TypeVar("_R")
 
@@ -47,7 +57,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
 
 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -175,9 +185,12 @@ class CustomAllreduce:
         # test nvlink first, this will filter out most of the cases
         # where custom allreduce is not supported
         # this checks hardware and driver support for NVLink
-        assert is_cuda()
+        if is_cuda():
+            assert is_cuda()
 
-        full_nvlink = is_full_nvlink(physical_device_ids)
+            full_nvlink = is_full_nvlink(physical_device_ids)
+        else:
+            full_nvlink = False
         if world_size > 2 and not full_nvlink:
             logger.warning(
                 "Custom allreduce is disabled because it's not supported on"
@@ -196,32 +209,64 @@ class CustomAllreduce:
             )
             return
 
-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
-        )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+        if ops.use_vllm_custom_allreduce:
+            # Buffers memory are owned by this Python class and passed to C++.
+            # Meta data composes of two parts: meta data for synchronization and a
+            # temporary buffer for storing intermediate allreduce results.
+            self.meta_ptrs = self.create_shared_buffer(
+                ops.meta_size() + max_size, group=group
+            )
+            # This is a pre-registered IPC buffer. In eager mode, input tensors
+            # are first copied into this buffer before allreduce is performed
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            # This is a buffer for storing the tuples of pointers pointing to
+            # IPC buffers from all ranks. Each registered tuple has size of
+            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+            # is enough for 131072 such tuples. The largest model I've seen only
+            # needs less than 10000 of registered tuples.
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            )
+            ops.register_buffer(self._ptr, self.buffer_ptrs)
+        else:
+            # From TensorRT-LLM getMaxRequiredWorkspaceSize
+            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+            self.barrier_max_size = 8 * (36 + 2) * 8
+
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                max_size, group=group
+            )
+            self.rank_data_base = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self.barrier_in_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self.barrier_out_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+
+            self._ptr = ops.init_custom_ar(
+                rank,
+                world_size,
+                self.rank_data_base,
+                self.buffer_ptrs,
+                self.tmp_result_buffer_ptrs,
+                self.barrier_in_ptrs,
+                self.barrier_out_ptrs,
+            )
+        self.disabled = False
 
     @staticmethod
     def create_shared_buffer(
@@ -300,12 +345,31 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if ops.use_vllm_custom_allreduce:
+            if self.world_size == 2 or self.full_nvlink:
+                return inp_size < self.max_size
+            return False
+
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
         return False
 
     def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
+        registered: bool = False,
     ):
         """Performs an out-of-place all reduce.
 
@@ -315,12 +379,15 @@ class CustomAllreduce:
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
+        if ops.use_vllm_custom_allreduce:
+            if registered:
+                ops.all_reduce(self._ptr, inp, out, 0, 0)
+            else:
+                ops.all_reduce(
+                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+                )
         else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+            ops.all_reduce(self._ptr, inp, out)
         return out
 
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -336,17 +403,20 @@ class CustomAllreduce:
             # allreduce is out-of-place.
             return torch.empty_like(input)
         else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
            return self.all_reduce(input, registered=False)
 
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
+            if ops.use_vllm_custom_allreduce:
+                self.free_shared_buffer(self.meta_ptrs)
+                self.free_shared_buffer(self.buffer_ptrs)
+            else:
+                self.free_shared_buffer(self.buffer_ptrs)
+                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+                self.free_shared_buffer(self.barrier_in_ptrs)
+                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
-            self.free_shared_buffer(self.buffer_ptrs)
 
     def __del__(self):
         self.close()
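Restated for clarity as a hypothetical standalone helper (not code from the package): the sgl_kernel path above gates custom allreduce by message size against TensorRT-LLM-style workspace limits, 16 MiB when there are exactly two ranks and 8 MiB when all GPUs are fully NVLinked.

# Thresholds copied from the diff; the function name and shape are illustrative.
_MAX_REQUIRED_WORKSPACE = [16 * 1024 * 1024, 8 * 1024 * 1024]  # [world_size == 2, full NVLink]

def fits_custom_allreduce(inp_size: int, max_size: int, world_size: int, full_nvlink: bool) -> bool:
    if world_size == 2:
        return inp_size < max_size and inp_size < _MAX_REQUIRED_WORKSPACE[0]
    if full_nvlink:
        return inp_size < max_size and inp_size < _MAX_REQUIRED_WORKSPACE[1]
    return False

# An 8 MiB payload on 4 fully NVLinked GPUs falls back to NCCL (not < 8 MiB):
print(fits_custom_allreduce(8 * 1024 * 1024, 2**24, world_size=4, full_nvlink=True))  # False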
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+
 import ctypes
 import json
 import logging
@@ -7,7 +8,6 @@ import pickle
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache
 from itertools import product
 from typing import Dict, List, Optional, Sequence
 
sglang/srt/distributed/device_communicators/hpu_communicator.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
sglang/srt/distributed/device_communicators/pynccl.py CHANGED
@@ -1,8 +1,10 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/pynccl.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/pynccl.py
+
 import logging
 from contextlib import contextmanager
 from typing import Optional, Union
 
+# ===================== import region =====================
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup, ReduceOp
@@ -143,6 +145,57 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )
 
+    def all_gather(
+        self, output_tensor: torch.Tensor, input_tensor: torch.Tensor, stream=None
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclAllGather(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()),
+            input_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
+    def reduce_scatter(
+        self,
+        output_tensor: torch.Tensor,
+        input_tensor: torch.Tensor,
+        op: ReduceOp = ReduceOp.SUM,
+        stream=None,
+    ):
+        if self.disabled:
+            return
+        # nccl communicator created on a specific device
+        # will only work on tensors on the same device
+        # otherwise it will cause "illegal memory access"
+        assert input_tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {input_tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        self.nccl.ncclReduceScatter(
+            buffer_type(input_tensor.data_ptr()),
+            buffer_type(output_tensor.data_ptr()),
+            output_tensor.numel(),
+            ncclDataTypeEnum.from_torch(input_tensor.dtype),
+            ncclRedOpTypeEnum.from_torch(op),
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
     def send(self, tensor: torch.Tensor, dst: int, stream=None):
         if self.disabled:
             return
@@ -179,6 +232,32 @@ class PyNcclCommunicator:
             cudaStream_t(stream.cuda_stream),
         )
 
+    def broadcast(self, tensor: torch.Tensor, src: int, stream=None):
+        if self.disabled:
+            return
+        assert tensor.device == self.device, (
+            f"this nccl communicator is created to work on {self.device}, "
+            f"but the input tensor is on {tensor.device}"
+        )
+        if stream is None:
+            stream = self.stream
+        if src == self.rank:
+            sendbuff = buffer_type(tensor.data_ptr())
+            # NCCL requires the sender also to have a receive buffer
+            recvbuff = buffer_type(tensor.data_ptr())
+        else:
+            sendbuff = buffer_type()
+            recvbuff = buffer_type(tensor.data_ptr())
+        self.nccl.ncclBroadcast(
+            sendbuff,
+            recvbuff,
+            tensor.numel(),
+            ncclDataTypeEnum.from_torch(tensor.dtype),
+            src,
+            self.comm,
+            cudaStream_t(stream.cuda_stream),
+        )
+
     @contextmanager
     def change_state(
         self, enable: Optional[bool] = None, stream: Optional[torch.cuda.Stream] = None
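The new PyNcclCommunicator collectives follow the usual NCCL shape contract; below is an illustrative sketch of the expected tensor sizes, not code from the package. The communicator (`comm`) needs an initialized NCCL group, so those calls are left commented.

import torch

world_size = 4
n = 1024
device = "cuda" if torch.cuda.is_available() else "cpu"

inp = torch.randn(n, device=device)
gathered = torch.empty(world_size * n, device=device)    # all_gather: output numel = world_size * input numel
scattered = torch.empty(n // world_size, device=device)  # reduce_scatter: output numel = input numel / world_size

# comm.all_gather(gathered, inp)       # every rank receives all ranks' inputs, ordered by rank
# comm.reduce_scatter(scattered, inp)  # each rank receives its reduced shard
# comm.broadcast(inp, src=0)           # in-place broadcast from rank 0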