sglang 0.4.1.post6__py3-none-any.whl → 0.4.1.post7__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (122)
  1. sglang/__init__.py +21 -23
  2. sglang/api.py +2 -7
  3. sglang/bench_offline_throughput.py +24 -16
  4. sglang/bench_one_batch.py +51 -3
  5. sglang/bench_one_batch_server.py +1 -1
  6. sglang/bench_serving.py +37 -28
  7. sglang/lang/backend/runtime_endpoint.py +183 -4
  8. sglang/lang/chat_template.py +15 -4
  9. sglang/launch_server.py +1 -1
  10. sglang/srt/_custom_ops.py +80 -42
  11. sglang/srt/configs/device_config.py +1 -1
  12. sglang/srt/configs/model_config.py +1 -0
  13. sglang/srt/constrained/base_grammar_backend.py +21 -0
  14. sglang/srt/constrained/xgrammar_backend.py +8 -4
  15. sglang/srt/conversation.py +14 -1
  16. sglang/srt/distributed/__init__.py +3 -3
  17. sglang/srt/distributed/communication_op.py +2 -1
  18. sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
  19. sglang/srt/distributed/device_communicators/custom_all_reduce.py +107 -40
  20. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  21. sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
  22. sglang/srt/distributed/device_communicators/pynccl.py +80 -1
  23. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
  24. sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
  25. sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
  26. sglang/srt/distributed/parallel_state.py +1 -1
  27. sglang/srt/distributed/utils.py +2 -1
  28. sglang/srt/entrypoints/engine.py +449 -0
  29. sglang/srt/entrypoints/http_server.py +579 -0
  30. sglang/srt/layers/activation.py +3 -3
  31. sglang/srt/layers/attention/flashinfer_backend.py +10 -9
  32. sglang/srt/layers/attention/triton_backend.py +4 -6
  33. sglang/srt/layers/attention/vision.py +204 -0
  34. sglang/srt/layers/dp_attention.py +69 -0
  35. sglang/srt/layers/linear.py +41 -5
  36. sglang/srt/layers/logits_processor.py +48 -63
  37. sglang/srt/layers/moe/ep_moe/layer.py +4 -4
  38. sglang/srt/layers/moe/fused_moe_native.py +69 -0
  39. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -6
  40. sglang/srt/layers/moe/fused_moe_triton/layer.py +29 -5
  41. sglang/srt/layers/parameter.py +2 -1
  42. sglang/srt/layers/quantization/__init__.py +20 -23
  43. sglang/srt/layers/quantization/fp8.py +6 -3
  44. sglang/srt/layers/quantization/modelopt_quant.py +1 -2
  45. sglang/srt/layers/quantization/w8a8_int8.py +1 -1
  46. sglang/srt/layers/radix_attention.py +2 -2
  47. sglang/srt/layers/rotary_embedding.py +1179 -31
  48. sglang/srt/layers/sampler.py +39 -1
  49. sglang/srt/layers/vocab_parallel_embedding.py +2 -2
  50. sglang/srt/lora/lora.py +1 -9
  51. sglang/srt/managers/configure_logging.py +3 -0
  52. sglang/srt/managers/data_parallel_controller.py +79 -72
  53. sglang/srt/managers/detokenizer_manager.py +23 -6
  54. sglang/srt/managers/image_processor.py +158 -2
  55. sglang/srt/managers/io_struct.py +25 -2
  56. sglang/srt/managers/schedule_batch.py +49 -22
  57. sglang/srt/managers/schedule_policy.py +26 -12
  58. sglang/srt/managers/scheduler.py +277 -178
  59. sglang/srt/managers/session_controller.py +1 -0
  60. sglang/srt/managers/tokenizer_manager.py +206 -121
  61. sglang/srt/managers/tp_worker.py +6 -4
  62. sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
  63. sglang/srt/managers/utils.py +44 -0
  64. sglang/srt/mem_cache/memory_pool.py +10 -32
  65. sglang/srt/metrics/collector.py +15 -6
  66. sglang/srt/model_executor/cuda_graph_runner.py +4 -6
  67. sglang/srt/model_executor/model_runner.py +37 -15
  68. sglang/srt/model_loader/loader.py +8 -6
  69. sglang/srt/model_loader/weight_utils.py +55 -2
  70. sglang/srt/models/baichuan.py +6 -6
  71. sglang/srt/models/chatglm.py +2 -2
  72. sglang/srt/models/commandr.py +3 -3
  73. sglang/srt/models/dbrx.py +4 -4
  74. sglang/srt/models/deepseek.py +3 -3
  75. sglang/srt/models/deepseek_v2.py +8 -8
  76. sglang/srt/models/exaone.py +2 -2
  77. sglang/srt/models/gemma.py +2 -2
  78. sglang/srt/models/gemma2.py +6 -24
  79. sglang/srt/models/gpt2.py +3 -5
  80. sglang/srt/models/gpt_bigcode.py +1 -1
  81. sglang/srt/models/granite.py +2 -2
  82. sglang/srt/models/grok.py +3 -3
  83. sglang/srt/models/internlm2.py +2 -2
  84. sglang/srt/models/llama.py +7 -5
  85. sglang/srt/models/minicpm.py +2 -2
  86. sglang/srt/models/minicpm3.py +6 -6
  87. sglang/srt/models/minicpmv.py +1238 -0
  88. sglang/srt/models/mixtral.py +3 -3
  89. sglang/srt/models/mixtral_quant.py +3 -3
  90. sglang/srt/models/mllama.py +2 -2
  91. sglang/srt/models/olmo.py +3 -3
  92. sglang/srt/models/olmo2.py +4 -4
  93. sglang/srt/models/olmoe.py +7 -13
  94. sglang/srt/models/phi3_small.py +2 -2
  95. sglang/srt/models/qwen.py +2 -2
  96. sglang/srt/models/qwen2.py +41 -4
  97. sglang/srt/models/qwen2_moe.py +3 -3
  98. sglang/srt/models/qwen2_vl.py +22 -122
  99. sglang/srt/models/stablelm.py +2 -2
  100. sglang/srt/models/torch_native_llama.py +3 -3
  101. sglang/srt/models/xverse.py +6 -6
  102. sglang/srt/models/xverse_moe.py +6 -6
  103. sglang/srt/openai_api/protocol.py +2 -0
  104. sglang/srt/sampling/custom_logit_processor.py +38 -0
  105. sglang/srt/sampling/sampling_batch_info.py +139 -4
  106. sglang/srt/sampling/sampling_params.py +3 -1
  107. sglang/srt/server.py +4 -1090
  108. sglang/srt/server_args.py +57 -14
  109. sglang/srt/utils.py +103 -65
  110. sglang/test/runners.py +8 -13
  111. sglang/test/test_programs.py +1 -1
  112. sglang/test/test_utils.py +3 -1
  113. sglang/utils.py +12 -2
  114. sglang/version.py +1 -1
  115. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/METADATA +16 -5
  116. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/RECORD +119 -115
  117. sglang/launch_server_llavavid.py +0 -25
  118. sglang/srt/constrained/__init__.py +0 -16
  119. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  120. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/LICENSE +0 -0
  121. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.1.post6.dist-info → sglang-0.4.1.post7.dist-info}/top_level.txt +0 -0
sglang/lang/chat_template.py CHANGED
@@ -88,7 +88,6 @@ register_chat_template(
     )
 )
 
-
 register_chat_template(
     ChatTemplate(
         name="claude",
@@ -101,7 +100,6 @@ register_chat_template(
     )
 )
 
-
 register_chat_template(
     ChatTemplate(
         name="chatml",
@@ -116,7 +114,6 @@ register_chat_template(
     )
 )
 
-
 register_chat_template(
     ChatTemplate(
         name="chatml-llava",
@@ -132,7 +129,6 @@ register_chat_template(
     )
 )
 
-
 # There is default system prompt for qwen
 # reference: https://modelscope.cn/models/qwen/Qwen2-72B-Instruct/file/view/master?fileName=tokenizer_config.json&status=1
 # The chat template is: "{% for message in messages %}{% if loop.first and messages[0]['role'] != 'system' %}{{ '<|im_start|>system\nYou are a helpful assistant.<|im_end|>\n' }}{% endif %}{{'<|im_start|>' + message['role'] + '\n' + message['content'] + '<|im_end|>' + '\n'}}{% endfor %}{% if add_generation_prompt %}{{ '<|im_start|>assistant\n' }}{% endif %}"
@@ -219,6 +215,21 @@ register_chat_template(
     )
 )
 
+# https://huggingface.co/openbmb/MiniCPM-V-2_6
+register_chat_template(
+    ChatTemplate(
+        name="minicpmv",
+        default_system_prompt=None,
+        role_prefix_and_suffix={
+            "system": ("", " "),
+            "user": ("user:", " "),
+            "assistant": ("assistant:", "</s>"),
+        },
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+    )
+)
+
 # The difference between "llama-3-instruct-llava" and "llama-3-instruct" is that llava uses a different image_token.
 register_chat_template(
     ChatTemplate(
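Note (not part of the diff): once this wheel is installed, the new template above is reachable through the registry accessor in sglang.lang.chat_template. A minimal sketch that only reads back the fields added in this hunk:

    from sglang.lang.chat_template import get_chat_template

    tmpl = get_chat_template("minicpmv")
    print(tmpl.image_token)   # "(<image>./</image>)"
    print(tmpl.stop_str)      # ("<|im_end|>", "<|endoftext|>")
    user_prefix, user_suffix = tmpl.role_prefix_and_suffix["user"]  # ("user:", " ")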
sglang/launch_server.py CHANGED
@@ -3,7 +3,7 @@
 import os
 import sys
 
-from sglang.srt.server import launch_server
+from sglang.srt.entrypoints.http_server import launch_server
 from sglang.srt.server_args import prepare_server_args
 from sglang.srt.utils import kill_process_tree
 
sglang/srt/_custom_ops.py CHANGED
@@ -1,8 +1,9 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/_custom_ops.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/_custom_ops.py
 import contextlib
 import functools
 import importlib
 import logging
+import os
 from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
 import torch
@@ -11,12 +12,19 @@ import torch.library
 from sglang.srt.utils import is_hpu
 
 logger = logging.getLogger(__name__)
+use_vllm_custom_allreduce = os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True)
 
 if not is_hpu():
-    try:
-        import custom_ar
-    except ImportError as e:
-        logger.warning("Failed to import from custom_ar with %r", e)
+    if use_vllm_custom_allreduce:
+        try:
+            import vllm._C
+        except ImportError as e:
+            logger.warning("Failed to import from vllm._C with %r", e)
+    else:
+        try:
+            import sgl_kernel
+        except ImportError as e:
+            logger.warning("Failed to import from custom_ar with %r", e)
 
 
 def hint_on_error(fn):
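Aside (not from the diff): the new toggle is read with os.environ.get("USE_VLLM_CUSTOM_ALLREDUCE", default=True), which returns the default True when the variable is unset but the raw string otherwise, so any non-empty value, including "0" or "false", is truthy. A minimal sketch of stricter parsing a caller could apply before relying on the flag (illustrative helper, not code shipped in the wheel):

    import os

    def env_flag(name: str, default: bool = True) -> bool:
        # os.environ.get returns a string when the variable is set,
        # so map common spellings onto a real boolean instead.
        raw = os.environ.get(name)
        if raw is None:
            return default
        return raw.strip().lower() in ("1", "true", "yes", "on")

    use_vllm_custom_allreduce = env_flag("USE_VLLM_CUSTOM_ALLREDUCE", default=True)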
@@ -48,48 +56,78 @@ def hint_on_error(fn):
     return wrapper
 
 
-# custom ar
-def init_custom_ar(
-    ipc_tensors: List[torch.Tensor],
-    rank_data: torch.Tensor,
-    rank: int,
-    full_nvlink: bool,
-) -> int:
-    return torch.ops._C_vllm_ar.init_custom_ar(
-        ipc_tensors, rank_data, rank, full_nvlink
-    )
-
-
-def all_reduce(
-    fa: int,
-    inp: torch.Tensor,
-    out: torch.Tensor,
-    reg_buffer: int,
-    reg_buffer_sz_bytes: int,
-) -> None:
-    torch.ops._C_vllm_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
-
-
-def dispose(fa: int) -> None:
-    torch.ops._C_vllm_ar.dispose(fa)
-
-
-def meta_size() -> int:
-    return torch.ops._C_vllm_ar.meta_size()
-
+if use_vllm_custom_allreduce:
+    # custom ar
+    def init_custom_ar(
+        ipc_tensors: List[torch.Tensor],
+        rank_data: torch.Tensor,
+        rank: int,
+        full_nvlink: bool,
+    ) -> int:
+        return torch.ops._C_custom_ar.init_custom_ar(
+            ipc_tensors, rank_data, rank, full_nvlink
+        )
 
-def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
-    return torch.ops._C_vllm_ar.register_buffer(fa, ipc_tensors)
+    def all_reduce(
+        fa: int,
+        inp: torch.Tensor,
+        out: torch.Tensor,
+        reg_buffer: int,
+        reg_buffer_sz_bytes: int,
+    ) -> None:
+        torch.ops._C_custom_ar.all_reduce(fa, inp, out, reg_buffer, reg_buffer_sz_bytes)
+
+    def dispose(fa: int) -> None:
+        torch.ops._C_custom_ar.dispose(fa)
+
+    def meta_size() -> int:
+        return torch.ops._C_custom_ar.meta_size()
+
+    def register_buffer(fa: int, ipc_tensors: List[int]) -> None:
+        return torch.ops._C_custom_ar.register_buffer(fa, ipc_tensors)
+
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return torch.ops._C_custom_ar.get_graph_buffer_ipc_meta(fa)
+
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        torch.ops._C_custom_ar.register_graph_buffers(fa, handles, offsets)
+
+else:
+    # custom ar
+    def init_custom_ar(
+        rank_id: int,
+        world_size: int,
+        rank_data_base: torch.Tensor,
+        buffers: List[int],
+        tmp_result_buffers: List[int],
+        barrier_in: List[int],
+        barrier_out: List[int],
+    ) -> int:
+        return sgl_kernel.ops.init_custom_reduce(
+            rank_id,
+            world_size,
+            rank_data_base,
+            buffers,
+            tmp_result_buffers,
+            barrier_in,
+            barrier_out,
+        )
 
+    def all_reduce(fa: int, inp: torch.Tensor, out: torch.Tensor) -> None:
+        sgl_kernel.ops.custom_reduce(fa, inp, out)
 
-def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
-    return torch.ops._C_vllm_ar.get_graph_buffer_ipc_meta(fa)
+    def dispose(fa: int) -> None:
+        sgl_kernel.ops.custom_dispose(fa)
 
+    def get_graph_buffer_ipc_meta(fa: int) -> Tuple[List[int], List[int]]:
+        return sgl_kernel.ops.get_graph_buffer_ipc_meta(fa)
 
-def register_graph_buffers(
-    fa: int, handles: List[List[int]], offsets: List[List[int]]
-) -> None:
-    torch.ops._C_vllm_ar.register_graph_buffers(fa, handles, offsets)
+    def register_graph_buffers(
+        fa: int, handles: List[List[int]], offsets: List[List[int]]
+    ) -> None:
+        sgl_kernel.ops.register_graph_buffers(fa, handles, offsets)
 
 
 # temporary fix for https://github.com/vllm-project/vllm/issues/5456
sglang/srt/configs/device_config.py CHANGED
@@ -10,7 +10,7 @@ class DeviceConfig:
     device: Optional[torch.device]
 
     def __init__(self, device: str = "cuda") -> None:
-        if device in ["cuda", "xpu", "hpu"]:
+        if device in ["cuda", "xpu", "hpu", "cpu"]:
            self.device_type = device
         else:
            raise RuntimeError(f"Not supported device type: {device}")
sglang/srt/configs/model_config.py CHANGED
@@ -402,6 +402,7 @@ def is_multimodal_model(model_architectures: List[str]):
         or "LlavaVidForCausalLM" in model_architectures
         or "MllamaForConditionalGeneration" in model_architectures
         or "Qwen2VLForConditionalGeneration" in model_architectures
+        or "MiniCPMV" in model_architectures
     ):
         return True
     else:
sglang/srt/constrained/base_grammar_backend.py CHANGED
@@ -18,6 +18,8 @@ from dataclasses import dataclass
 from threading import Event, Lock
 from typing import Any, Optional, Tuple
 
+from sglang.srt.server_args import ServerArgs
+
 
 @dataclass
 class CacheEntry:
@@ -69,3 +71,22 @@ class BaseGrammarBackend:
     def reset(self):
         with self.cache_lock:
             self.cache.clear()
+
+
+def create_grammar_backend(server_args: ServerArgs, tokenizer, vocab_size):
+    if server_args.grammar_backend == "outlines":
+        from sglang.srt.constrained.outlines_backend import OutlinesGrammarBackend
+
+        grammar_backend = OutlinesGrammarBackend(
+            tokenizer,
+            whitespace_pattern=server_args.constrained_json_whitespace_pattern,
+            allow_jump_forward=not server_args.disable_jump_forward,
+        )
+    elif server_args.grammar_backend == "xgrammar":
+        from sglang.srt.constrained.xgrammar_backend import XGrammarGrammarBackend
+
+        grammar_backend = XGrammarGrammarBackend(tokenizer, vocab_size=vocab_size)
+    else:
+        raise ValueError(f"Invalid grammar backend: {server_args.grammar_backend}")
+
+    return grammar_backend
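For context, a minimal sketch (not part of the diff) of how the new create_grammar_backend helper can be exercised. ServerArgs is replaced with a SimpleNamespace stand-in carrying only the three attributes the helper reads, and the tokenizer name is just an example; this assumes the chosen backend (outlines here) is installed.

    from types import SimpleNamespace

    from transformers import AutoTokenizer

    from sglang.srt.constrained.base_grammar_backend import create_grammar_backend

    tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")  # example model

    # Stand-in for ServerArgs: only the fields read by create_grammar_backend.
    args = SimpleNamespace(
        grammar_backend="outlines",
        constrained_json_whitespace_pattern=None,
        disable_jump_forward=False,
    )
    backend = create_grammar_backend(args, tokenizer, tokenizer.vocab_size)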
sglang/srt/constrained/xgrammar_backend.py CHANGED
@@ -19,6 +19,7 @@ from typing import List, Tuple
 import torch
 from xgrammar import (
     CompiledGrammar,
+    Grammar,
     GrammarCompiler,
     GrammarMatcher,
     TokenizerInfo,
@@ -133,10 +134,13 @@ class XGrammarGrammarBackend(BaseGrammarBackend):
             logging.warning(f"Skip invalid ebnf: ebnf={key_string}, {e=}")
             return None
         elif key_type == "regex":
-            logger.warning(
-                "regex hasn't been supported by xgrammar yet. This is skipped."
-            )
-            return None
+            try:
+                ctx = self.grammar_compiler.compile_grammar(
+                    Grammar.from_regex(key_string)
+                )
+            except RuntimeError as e:
+                logging.warning(f"Skip invalid regex: regex={key_string}, {e=}")
+                return None
         else:
             raise ValueError(f"Invalid key_type: {key_type}")
 
sglang/srt/conversation.py CHANGED
@@ -452,7 +452,6 @@ def generate_chat_conv(
 
     # Add a blank message for the assistant.
     conv.append_message(conv.roles[1], None)
-
     return conv
 
 
@@ -555,3 +554,17 @@ register_conv_template(
         image_token="<|vision_start|><|image_pad|><|vision_end|>",
     )
 )
+
+# Reference: https://huggingface.co/openbmb/MiniCPM-V-2_6#usage
+register_conv_template(
+    Conversation(
+        name="minicpmv",
+        system_message="You are a helpful assistant",
+        system_template="<|im_start|>system\n{system_message}.",
+        roles=("<|im_start|>user", "<|im_start|>assistant"),
+        sep="<|im_end|>\n",
+        sep_style=SeparatorStyle.ADD_NEW_LINE_SINGLE,
+        stop_str=("<|im_end|>", "<|endoftext|>"),
+        image_token="(<image>./</image>)",
+    )
+)
sglang/srt/distributed/__init__.py CHANGED
@@ -1,3 +1,3 @@
-from .communication_op import *
-from .parallel_state import *
-from .utils import *
+from sglang.srt.distributed.communication_op import *
+from sglang.srt.distributed.parallel_state import *
+from sglang.srt.distributed.utils import *
sglang/srt/distributed/communication_op.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/communication_op.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/communication_op.py
+
 from typing import Any, Dict, Optional, Union
 
 import torch
sglang/srt/distributed/device_communicators/cuda_wrapper.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/cuda_wrapper.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/cuda_wrapper.py
+
 """This file is a pure Python wrapper for the cudart library.
 It avoids the need to compile a separate shared library, and is
 convenient for use when we just need to call a few functions.
sglang/srt/distributed/device_communicators/custom_all_reduce.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce.py
+
 import ctypes
 import logging
 import os
@@ -6,7 +7,6 @@ from contextlib import contextmanager
 from functools import wraps
 from typing import Callable, List, Optional, TypeVar, Union
 
-import pynvml
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup
@@ -20,8 +20,19 @@ from sglang.srt.distributed.device_communicators.custom_all_reduce_utils import
 from sglang.srt.distributed.parallel_state import in_the_same_node_as
 from sglang.srt.utils import cuda_device_count_stateless, is_cuda
 
+logger = logging.getLogger(__name__)
+
+if is_cuda():
+    try:
+        import pynvml
+    except ImportError as e:
+        logger.warning("Failed to import pynvml with %r", e)
+
 try:
-    ops.meta_size()
+    if ops.use_vllm_custom_allreduce:
+        ops.meta_size()
+    else:
+        import sgl_kernel
     custom_ar = True
 except Exception:
     # For AMD GPUs and CPUs
@@ -29,7 +40,6 @@ except Exception:
 
 logger = logging.getLogger(__name__)
 
-
 _P = ParamSpec("_P")
 _R = TypeVar("_R")
 
@@ -47,7 +57,7 @@ def with_nvml_context(fn: Callable[_P, _R]) -> Callable[_P, _R]:
 
 
 @with_nvml_context
-def is_full_nvlink(cls, physical_device_ids: List[int]) -> bool:
+def is_full_nvlink(physical_device_ids: List[int]) -> bool:
     """
     query if the set of gpus are fully connected by nvlink (1 hop)
     """
@@ -196,32 +206,64 @@ class CustomAllreduce:
             )
             return
 
-        self.disabled = False
-        # Buffers memory are owned by this Python class and passed to C++.
-        # Meta data composes of two parts: meta data for synchronization and a
-        # temporary buffer for storing intermediate allreduce results.
-        self.meta_ptrs = self.create_shared_buffer(
-            ops.meta_size() + max_size, group=group
-        )
-        # This is a pre-registered IPC buffer. In eager mode, input tensors
-        # are first copied into this buffer before allreduce is performed
-        self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
-        # This is a buffer for storing the tuples of pointers pointing to
-        # IPC buffers from all ranks. Each registered tuple has size of
-        # 8*world_size bytes where world_size is at most 8. Allocating 8MB
-        # is enough for 131072 such tuples. The largest model I've seen only
-        # needs less than 10000 of registered tuples.
-        self.rank_data = torch.empty(
-            8 * 1024 * 1024, dtype=torch.uint8, device=self.device
-        )
         self.max_size = max_size
         self.rank = rank
         self.world_size = world_size
         self.full_nvlink = full_nvlink
-        self._ptr = ops.init_custom_ar(
-            self.meta_ptrs, self.rank_data, rank, self.full_nvlink
-        )
-        ops.register_buffer(self._ptr, self.buffer_ptrs)
+
+        if ops.use_vllm_custom_allreduce:
+            # Buffers memory are owned by this Python class and passed to C++.
+            # Meta data composes of two parts: meta data for synchronization and a
+            # temporary buffer for storing intermediate allreduce results.
+            self.meta_ptrs = self.create_shared_buffer(
+                ops.meta_size() + max_size, group=group
+            )
+            # This is a pre-registered IPC buffer. In eager mode, input tensors
+            # are first copied into this buffer before allreduce is performed
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            # This is a buffer for storing the tuples of pointers pointing to
+            # IPC buffers from all ranks. Each registered tuple has size of
+            # 8*world_size bytes where world_size is at most 8. Allocating 8MB
+            # is enough for 131072 such tuples. The largest model I've seen only
+            # needs less than 10000 of registered tuples.
+            self.rank_data = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self._ptr = ops.init_custom_ar(
+                self.meta_ptrs, self.rank_data, rank, self.full_nvlink
+            )
+            ops.register_buffer(self._ptr, self.buffer_ptrs)
+        else:
+            # From TensorRT-LLM getMaxRequiredWorkspaceSize
+            self.max_required_workspace_size = [16 * 1024 * 1024, 8 * 1024 * 1024]
+
+            # sizeof(uint32_t) * (MAX_ALL_REDUCE_BLOCKS + 2) * MAX_RANKS_PER_NODE;
+            self.barrier_max_size = 8 * (36 + 2) * 8
+
+            self.buffer_ptrs = self.create_shared_buffer(max_size, group=group)
+            self.tmp_result_buffer_ptrs = self.create_shared_buffer(
+                max_size, group=group
+            )
+            self.rank_data_base = torch.empty(
+                8 * 1024 * 1024, dtype=torch.uint8, device=self.device
+            )
+            self.barrier_in_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+            self.barrier_out_ptrs = self.create_shared_buffer(
+                self.barrier_max_size, group=group
+            )
+
+            self._ptr = ops.init_custom_ar(
+                rank,
+                world_size,
+                self.rank_data_base,
+                self.buffer_ptrs,
+                self.tmp_result_buffer_ptrs,
+                self.barrier_in_ptrs,
+                self.barrier_out_ptrs,
+            )
+        self.disabled = False
 
     @staticmethod
     def create_shared_buffer(
@@ -300,12 +342,31 @@ class CustomAllreduce:
             return False
         # for 4 or more non NVLink-capable GPUs, custom allreduce provides
         # little performance improvement over NCCL.
-        if self.world_size == 2 or self.full_nvlink:
-            return inp_size < self.max_size
+        if ops.use_vllm_custom_allreduce:
+            if self.world_size == 2 or self.full_nvlink:
+                return inp_size < self.max_size
+            return False
+
+        if self.world_size == 2:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[0]
+            )
+
+        if self.full_nvlink:
+            return (
+                inp_size < self.max_size
+                and inp_size < self.max_required_workspace_size[1]
+            )
+
         return False
 
     def all_reduce(
-        self, inp: torch.Tensor, *, out: torch.Tensor = None, registered: bool = False
+        self,
+        inp: torch.Tensor,
+        *,
+        out: torch.Tensor = None,
+        registered: bool = False,
     ):
         """Performs an out-of-place all reduce.
 
@@ -315,12 +376,15 @@
         """
         if out is None:
             out = torch.empty_like(inp)
-        if registered:
-            ops.all_reduce(self._ptr, inp, out, 0, 0)
+        if ops.use_vllm_custom_allreduce:
+            if registered:
+                ops.all_reduce(self._ptr, inp, out, 0, 0)
+            else:
+                ops.all_reduce(
+                    self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
+                )
         else:
-            ops.all_reduce(
-                self._ptr, inp, out, self.buffer_ptrs[self.rank], self.max_size
-            )
+            ops.all_reduce(self._ptr, inp, out)
         return out
 
     def custom_all_reduce(self, input: torch.Tensor) -> Optional[torch.Tensor]:
@@ -336,17 +400,20 @@ class CustomAllreduce:
             # allreduce is out-of-place.
             return torch.empty_like(input)
         else:
-            # Note: outside of cuda graph context, custom allreduce incurs a
-            # cost of cudaMemcpy, which should be small (<=1% of overall
-            # latency) compared to the performance gain of using custom kernels
             return self.all_reduce(input, registered=False)
 
     def close(self):
         if not self.disabled and self._ptr:
             ops.dispose(self._ptr)
+            if ops.use_vllm_custom_allreduce:
+                self.free_shared_buffer(self.meta_ptrs)
+                self.free_shared_buffer(self.buffer_ptrs)
+            else:
+                self.free_shared_buffer(self.buffer_ptrs)
+                self.free_shared_buffer(self.tmp_result_buffer_ptrs)
+                self.free_shared_buffer(self.barrier_in_ptrs)
+                self.free_shared_buffer(self.barrier_out_ptrs)
             self._ptr = 0
-            self.free_shared_buffer(self.meta_ptrs)
-            self.free_shared_buffer(self.buffer_ptrs)
 
     def __del__(self):
         self.close()
sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/custom_all_reduce_utils.py
+
 import ctypes
 import json
 import logging
@@ -7,7 +8,6 @@ import pickle
 import subprocess
 import sys
 import tempfile
-from functools import lru_cache
 from itertools import product
 from typing import Dict, List, Optional, Sequence
 
sglang/srt/distributed/device_communicators/hpu_communicator.py CHANGED
@@ -1,4 +1,5 @@
-# Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/device_communicators/hpu_communicator.py
+# Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/distributed/device_communicators/hpu_communicator.py
+
 import torch
 import torch.distributed as dist
 from torch.distributed import ProcessGroup