sglang 0.3.6.post3__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99) hide show
  1. sglang/bench_one_batch.py +4 -0
  2. sglang/bench_serving.py +13 -0
  3. sglang/check_env.py +1 -1
  4. sglang/srt/_custom_ops.py +118 -0
  5. sglang/srt/configs/device_config.py +17 -0
  6. sglang/srt/configs/load_config.py +84 -0
  7. sglang/srt/configs/model_config.py +161 -4
  8. sglang/srt/configs/qwen2vl.py +5 -8
  9. sglang/srt/constrained/outlines_backend.py +6 -1
  10. sglang/srt/constrained/outlines_jump_forward.py +8 -1
  11. sglang/srt/distributed/__init__.py +3 -0
  12. sglang/srt/distributed/communication_op.py +34 -0
  13. sglang/srt/distributed/device_communicators/__init__.py +0 -0
  14. sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
  15. sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
  16. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
  17. sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
  18. sglang/srt/distributed/device_communicators/pynccl.py +204 -0
  19. sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
  21. sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
  22. sglang/srt/distributed/parallel_state.py +1275 -0
  23. sglang/srt/distributed/utils.py +223 -0
  24. sglang/srt/hf_transformers_utils.py +37 -1
  25. sglang/srt/layers/attention/flashinfer_backend.py +13 -15
  26. sglang/srt/layers/attention/torch_native_backend.py +285 -0
  27. sglang/srt/layers/fused_moe_patch.py +20 -11
  28. sglang/srt/layers/linear.py +1 -0
  29. sglang/srt/layers/logits_processor.py +17 -3
  30. sglang/srt/layers/quantization/__init__.py +34 -0
  31. sglang/srt/layers/vocab_parallel_embedding.py +1 -0
  32. sglang/srt/lora/lora.py +1 -1
  33. sglang/srt/managers/io_struct.py +48 -2
  34. sglang/srt/managers/schedule_batch.py +18 -14
  35. sglang/srt/managers/schedule_policy.py +7 -4
  36. sglang/srt/managers/scheduler.py +76 -20
  37. sglang/srt/managers/tokenizer_manager.py +166 -68
  38. sglang/srt/managers/tp_worker.py +36 -3
  39. sglang/srt/managers/tp_worker_overlap_thread.py +21 -3
  40. sglang/srt/model_executor/cuda_graph_runner.py +16 -7
  41. sglang/srt/model_executor/forward_batch_info.py +9 -4
  42. sglang/srt/model_executor/model_runner.py +136 -150
  43. sglang/srt/model_loader/__init__.py +34 -0
  44. sglang/srt/model_loader/loader.py +1139 -0
  45. sglang/srt/model_loader/utils.py +41 -0
  46. sglang/srt/model_loader/weight_utils.py +640 -0
  47. sglang/srt/models/baichuan.py +9 -10
  48. sglang/srt/models/chatglm.py +6 -15
  49. sglang/srt/models/commandr.py +2 -3
  50. sglang/srt/models/dbrx.py +2 -3
  51. sglang/srt/models/deepseek.py +4 -11
  52. sglang/srt/models/deepseek_v2.py +3 -11
  53. sglang/srt/models/exaone.py +2 -3
  54. sglang/srt/models/gemma.py +2 -6
  55. sglang/srt/models/gemma2.py +3 -14
  56. sglang/srt/models/gemma2_reward.py +0 -1
  57. sglang/srt/models/gpt2.py +5 -12
  58. sglang/srt/models/gpt_bigcode.py +6 -22
  59. sglang/srt/models/grok.py +3 -3
  60. sglang/srt/models/internlm2.py +2 -3
  61. sglang/srt/models/internlm2_reward.py +0 -1
  62. sglang/srt/models/llama.py +97 -27
  63. sglang/srt/models/llama_classification.py +1 -2
  64. sglang/srt/models/llama_embedding.py +1 -2
  65. sglang/srt/models/llama_reward.py +2 -3
  66. sglang/srt/models/llava.py +1 -4
  67. sglang/srt/models/llavavid.py +1 -2
  68. sglang/srt/models/minicpm.py +4 -7
  69. sglang/srt/models/minicpm3.py +6 -19
  70. sglang/srt/models/mixtral.py +12 -5
  71. sglang/srt/models/mixtral_quant.py +2 -3
  72. sglang/srt/models/mllama.py +3 -7
  73. sglang/srt/models/olmo.py +2 -8
  74. sglang/srt/models/olmo2.py +0 -1
  75. sglang/srt/models/olmoe.py +3 -5
  76. sglang/srt/models/phi3_small.py +8 -8
  77. sglang/srt/models/qwen.py +2 -3
  78. sglang/srt/models/qwen2.py +10 -9
  79. sglang/srt/models/qwen2_moe.py +4 -11
  80. sglang/srt/models/qwen2_vl.py +2 -6
  81. sglang/srt/models/registry.py +99 -0
  82. sglang/srt/models/stablelm.py +2 -3
  83. sglang/srt/models/torch_native_llama.py +6 -12
  84. sglang/srt/models/xverse.py +2 -4
  85. sglang/srt/models/xverse_moe.py +4 -11
  86. sglang/srt/models/yivl.py +2 -3
  87. sglang/srt/openai_api/adapter.py +9 -5
  88. sglang/srt/openai_api/protocol.py +1 -0
  89. sglang/srt/server.py +267 -170
  90. sglang/srt/server_args.py +65 -31
  91. sglang/srt/utils.py +245 -28
  92. sglang/test/test_utils.py +7 -0
  93. sglang/version.py +1 -1
  94. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/METADATA +1 -1
  95. sglang-0.4.0.dist-info/RECORD +184 -0
  96. sglang-0.3.6.post3.dist-info/RECORD +0 -162
  97. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/LICENSE +0 -0
  98. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/WHEEL +0 -0
  99. {sglang-0.3.6.post3.dist-info → sglang-0.4.0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1275 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/distributed/parallel_state.py
2
+
3
+ # Copyright 2023 The vLLM team.
4
+ # Adapted from
5
+ # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/core/parallel_state.py
6
+ # Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
7
+ """vLLM distributed state.
8
+ It takes over the control of the distributed environment from PyTorch.
9
+ The typical workflow is:
10
+
11
+ - call `init_distributed_environment` to initialize the distributed environment.
12
+ - call `initialize_model_parallel` or `ensure_model_parallel_initialized` to
13
+ initialize the model parallel groups.
14
+
15
+ - any code dealing with the distributed stuff
16
+
17
+ - call `destroy_model_parallel` to destroy the model parallel groups.
18
+ - call `destroy_distributed_environment` to destroy the distributed environment.
19
+
20
+ If you only need to use the distributed environment without model/pipeline
21
+ parallelism, you can skip the model parallel initialization and destruction
22
+ steps.
23
+ """
24
+ import contextlib
25
+ import gc
26
+ import logging
27
+ import os
28
+ import pickle
29
+ import weakref
30
+ from collections import namedtuple
31
+ from contextlib import contextmanager, nullcontext
32
+ from dataclasses import dataclass
33
+ from multiprocessing import shared_memory
34
+ from typing import Any, Callable, Dict, List, Optional, Tuple, Union
35
+ from unittest.mock import patch
36
+
37
+ import torch
38
+ import torch.distributed
39
+ from torch.distributed import Backend, ProcessGroup
40
+
41
+ from sglang.srt.utils import (
42
+ direct_register_custom_op,
43
+ is_cuda_alike,
44
+ supports_custom_op,
45
+ )
46
+
47
+
48
@dataclass
class GraphCaptureContext:
    """Bundle of state handed around while a graph is being captured.

    The only piece of state is the stream on which the capture runs.
    """

    # Stream used for graph capture.
    stream: torch.cuda.Stream
51
+
52
+
53
+ TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
54
+
55
+
56
+ def _split_tensor_dict(
57
+ tensor_dict: Dict[str, Union[torch.Tensor, Any]]
58
+ ) -> Tuple[List[Tuple[str, Any]], List[torch.Tensor]]:
59
+ """Split the tensor dictionary into two parts:
60
+ 1. A list of (key, value) pairs. If the value is a tensor, it is replaced
61
+ by its metadata.
62
+ 2. A list of tensors.
63
+ """
64
+ metadata_list: List[Tuple[str, Any]] = []
65
+ tensor_list: List[torch.Tensor] = []
66
+ for key, value in tensor_dict.items():
67
+ if isinstance(value, torch.Tensor):
68
+ # Note: we cannot use `value.device` here,
69
+ # because it contains not only the device type but also the device
70
+ # index (e.g. "cuda:0"). We only need the device type.
71
+ # receiving side will set the device index.
72
+ device = value.device.type
73
+ metadata_list.append(
74
+ (key, TensorMetadata(device, value.dtype, value.size()))
75
+ )
76
+ tensor_list.append(value)
77
+ else:
78
+ metadata_list.append((key, value))
79
+ return metadata_list, tensor_list
80
+
81
+
82
+ _group_name_counter: Dict[str, int] = {}
83
+
84
+
85
+ def _get_unique_name(name: str) -> str:
86
+ """Get a unique name for the group.
87
+ Example:
88
+ _get_unique_name("tp") -> "tp:0"
89
+ _get_unique_name("tp") -> "tp:1"
90
+ """
91
+ if name not in _group_name_counter:
92
+ _group_name_counter[name] = 0
93
+ newname = f"{name}:{_group_name_counter[name]}"
94
+ _group_name_counter[name] += 1
95
+ return newname
96
+
97
+
98
+ _groups: Dict[str, Callable[[], Optional["GroupCoordinator"]]] = {}
99
+
100
+
101
+ def _register_group(group: "GroupCoordinator") -> None:
102
+ _groups[group.unique_name] = weakref.ref(group)
103
+
104
+
105
if supports_custom_op():
    # The custom ops receive the group *name* and look the coordinator up in
    # `_groups`, because a coordinator object itself cannot be passed to a
    # custom op (see `GroupCoordinator.all_reduce`).

    def inplace_all_reduce(tensor: torch.Tensor, group_name: str) -> None:
        """All-reduce ``tensor`` in place through the named group."""
        assert group_name in _groups, f"Group {group_name} is not found."
        coordinator = _groups[group_name]()
        if coordinator is None:
            # The weak reference is dead: the coordinator was collected.
            raise ValueError(f"Group {group_name} is destroyed.")
        coordinator._all_reduce_in_place(tensor)

    def inplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> None:
        """Fake implementation: nothing to return for the in-place op."""
        return

    def outplace_all_reduce(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        """All-reduce ``tensor`` through the named group into a new tensor."""
        assert group_name in _groups, f"Group {group_name} is not found."
        coordinator = _groups[group_name]()
        if coordinator is None:
            # The weak reference is dead: the coordinator was collected.
            raise ValueError(f"Group {group_name} is destroyed.")
        return coordinator._all_reduce_out_place(tensor)

    def outplace_all_reduce_fake(tensor: torch.Tensor, group_name: str) -> torch.Tensor:
        """Fake implementation: an uninitialised tensor shaped like the input."""
        return torch.empty_like(tensor)

    # Registration order is kept: in-place op first, then out-of-place op.
    direct_register_custom_op(
        op_name="inplace_all_reduce",
        op_func=inplace_all_reduce,
        mutates_args=["tensor"],
        fake_impl=inplace_all_reduce_fake,
    )

    direct_register_custom_op(
        op_name="outplace_all_reduce",
        op_func=outplace_all_reduce,
        mutates_args=[],
        fake_impl=outplace_all_reduce_fake,
    )
140
+
141
+
142
+ class GroupCoordinator:
143
+ """
144
+ PyTorch ProcessGroup wrapper for a group of processes.
145
+ PyTorch ProcessGroup is bound to one specific communication backend,
146
+ e.g. NCCL, Gloo, MPI, etc.
147
+ GroupCoordinator takes charge of all the communication operations among
148
+ the processes in the group. It can route the communication to
149
+ a specific implementation (e.g. switch allreduce implementation
150
+ based on the tensor size and cuda graph mode).
151
+ """
152
+
153
+ # available attributes:
154
+ rank: int # global rank
155
+ ranks: List[int] # global ranks in the group
156
+ world_size: int # size of the group
157
+ # difference between `local_rank` and `rank_in_group`:
158
+ # if we have a group of size 4 across two nodes:
159
+ # Process | Node | Rank | Local Rank | Rank in Group
160
+ # 0 | 0 | 0 | 0 | 0
161
+ # 1 | 0 | 1 | 1 | 1
162
+ # 2 | 1 | 2 | 0 | 2
163
+ # 3 | 1 | 3 | 1 | 3
164
+ local_rank: int # local rank used to assign devices
165
+ rank_in_group: int # rank inside the group
166
+ cpu_group: ProcessGroup # group for CPU communication
167
+ device_group: ProcessGroup # group for device communication
168
+ use_pynccl: bool # a hint of whether to use PyNccl
169
+ use_custom_allreduce: bool # a hint of whether to use CustomAllreduce
170
+ # communicators are only created for world size > 1
171
+ pynccl_comm: Optional[Any] # PyNccl communicator
172
+ ca_comm: Optional[Any] # Custom allreduce communicator
173
+ mq_broadcaster: Optional[Any] # shared memory broadcaster
174
+
175
def __init__(
    self,
    group_ranks: List[List[int]],
    local_rank: int,
    torch_distributed_backend: Union[str, Backend],
    use_pynccl: bool,
    use_custom_allreduce: bool,
    use_hpu_communicator: bool,
    use_xpu_communicator: bool,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
):
    """Create the process groups and the optional device communicators.

    Args:
        group_ranks: One list of global ranks per group to create. The
            calling process must appear in one of them.
        local_rank: Local rank, used to pick the device (``cuda:{local_rank}``).
        torch_distributed_backend: Backend for the device process group.
        use_pynccl: Create a PyNccl communicator (only if world size > 1).
        use_custom_allreduce: Create a CustomAllreduce communicator
            (only if world size > 1).
        use_hpu_communicator: Create an HPU communicator (only if world size > 1).
        use_xpu_communicator: Create an XPU communicator (only if world size > 1).
        use_message_queue_broadcaster: Create a shared-memory message queue
            for object broadcast (only if world size > 1).
        group_name: Base name of the group; defaults to "anonymous". A
            running index is appended to make it unique.

    Raises:
        AssertionError: If the calling process is in none of ``group_ranks``.
    """
    group_name = group_name or "anonymous"
    self.unique_name = _get_unique_name(group_name)
    _register_group(self)

    self.rank = torch.distributed.get_rank()
    self.local_rank = local_rank
    self.device_group = None
    self.cpu_group = None

    # Groups are created for every rank list, not only the one this process
    # belongs to (torch.distributed.new_group is documented as a call that
    # all processes must make).
    for ranks in group_ranks:
        device_group = torch.distributed.new_group(
            ranks, backend=torch_distributed_backend
        )
        # a group with `gloo` backend, to allow direct coordination between
        # processes through the CPU.
        cpu_group = torch.distributed.new_group(ranks, backend="gloo")
        if self.rank in ranks:
            self.ranks = ranks
            self.world_size = len(ranks)
            self.rank_in_group = ranks.index(self.rank)
            self.device_group = device_group
            self.cpu_group = cpu_group

    assert self.cpu_group is not None
    assert self.device_group is not None

    if is_cuda_alike():
        self.device = torch.device(f"cuda:{local_rank}")
    else:
        self.device = torch.device("cpu")

    self.use_pynccl = use_pynccl
    self.use_custom_allreduce = use_custom_allreduce
    self.use_hpu_communicator = use_hpu_communicator
    self.use_xpu_communicator = use_xpu_communicator

    # lazy import to avoid documentation build error
    from sglang.srt.distributed.device_communicators.custom_all_reduce import (
        CustomAllreduce,
    )
    from sglang.srt.distributed.device_communicators.pynccl import (
        PyNcclCommunicator,
    )

    self.pynccl_comm: Optional[PyNcclCommunicator] = None
    if use_pynccl and self.world_size > 1:
        self.pynccl_comm = PyNcclCommunicator(
            group=self.cpu_group,
            device=self.device,
        )

    self.ca_comm: Optional[CustomAllreduce] = None
    if use_custom_allreduce and self.world_size > 1:
        # Initialize a custom fast all-reduce implementation.
        self.ca_comm = CustomAllreduce(
            group=self.cpu_group,
            device=self.device,
        )

    from sglang.srt.distributed.device_communicators.hpu_communicator import (
        HpuCommunicator,
    )

    # Must be assigned, not just annotated: a bare annotation creates no
    # attribute, and all_reduce/all_gather read `self.hpu_communicator`
    # even when the HPU communicator is not in use.
    self.hpu_communicator: Optional[HpuCommunicator] = None
    if use_hpu_communicator and self.world_size > 1:
        self.hpu_communicator = HpuCommunicator(group=self.device_group)

    from sglang.srt.distributed.device_communicators.xpu_communicator import (
        XpuCommunicator,
    )

    # Same as above: all_reduce/gather read `self.xpu_communicator`
    # unconditionally, so it needs a None default.
    self.xpu_communicator: Optional[XpuCommunicator] = None
    if use_xpu_communicator and self.world_size > 1:
        self.xpu_communicator = XpuCommunicator(group=self.device_group)

    from sglang.srt.distributed.device_communicators.shm_broadcast import (
        MessageQueue,
    )

    self.mq_broadcaster: Optional[MessageQueue] = None
    if use_message_queue_broadcaster and self.world_size > 1:
        # NOTE(review): `1 << 22` and `6` presumably size the queue
        # (bytes per chunk, number of chunks) -- confirm against MessageQueue.
        self.mq_broadcaster = MessageQueue.create_from_process_group(
            self.cpu_group, 1 << 22, 6
        )
271
+
272
@property
def first_rank(self):
    """Global rank of the first process listed in this group."""
    leading = self.ranks[0]
    return leading
276
+
277
@property
def last_rank(self):
    """Global rank of the last process listed in this group."""
    trailing = self.ranks[-1]
    return trailing
281
+
282
@property
def is_first_rank(self):
    """True when the calling process is the first process of the group."""
    own_rank = self.rank
    return own_rank == self.first_rank
286
+
287
@property
def is_last_rank(self):
    """True when the calling process is the last process of the group."""
    own_rank = self.rank
    return own_rank == self.last_rank
291
+
292
@property
def next_rank(self):
    """Global rank of the process after the caller, wrapping at the end."""
    successor = (self.rank_in_group + 1) % self.world_size
    return self.ranks[successor]
298
+
299
@property
def prev_rank(self):
    """Global rank of the process before the caller, wrapping at the start."""
    predecessor = (self.rank_in_group - 1) % self.world_size
    return self.ranks[predecessor]
305
+
306
@contextmanager
def graph_capture(
    self, graph_capture_context: Optional[GraphCaptureContext] = None
):
    """Context manager that sets up the communicators for graph capture.

    Args:
        graph_capture_context: Existing context whose stream is reused. When
            omitted, a new ``torch.cuda.Stream`` and context are created.

    Yields:
        The ``GraphCaptureContext`` in effect for the capture.
    """
    if graph_capture_context is not None:
        stream = graph_capture_context.stream
    else:
        stream = torch.cuda.Stream()
        graph_capture_context = GraphCaptureContext(stream)

    ca_comm = self.ca_comm
    if ca_comm is None:
        ca_capture = nullcontext()
    else:
        ca_capture = ca_comm.capture()

    # ensure all initialization operations complete before attempting to
    # capture the graph on another stream
    active_stream = torch.cuda.current_stream()
    if active_stream != stream:
        stream.wait_stream(active_stream)

    with torch.cuda.stream(stream), ca_capture:
        # In graph mode, we have to be very careful about the collective
        # operations. The current status is:
        #     allreduce \ Mode   |  Eager  |  Graph  |
        # --------------------------------------------
        # custom allreduce       | enabled | enabled |
        # PyNccl                 | disabled| enabled |
        # torch.distributed      | enabled | disabled|
        #
        # Note that custom allreduce will have a runtime check, if the
        # tensor size is too large, it will fallback to the next
        # available option.
        # In summary: When using CUDA graph, we use
        # either custom all-reduce kernel or pynccl. When not using
        # CUDA graph, we use either custom all-reduce kernel or
        # PyTorch NCCL. We always prioritize using custom all-reduce
        # kernel but fall back to PyTorch or pynccl if it is
        # disabled or not supported.
        pynccl_comm = self.pynccl_comm
        pynccl_state: Any
        if pynccl_comm:
            # Enable PyNccl on the stream that is current inside the capture.
            pynccl_state = pynccl_comm.change_state(
                enable=True, stream=torch.cuda.current_stream()
            )
        else:
            pynccl_state = nullcontext()
        with pynccl_state:
            yield graph_capture_context
353
+
354
def all_reduce(self, input_: torch.Tensor) -> torch.Tensor:
    """
    User-facing all-reduce function before we actually call the
    all-reduce operation.

    We need this because Dynamo does not support passing an arbitrary
    object (`self` in this case) to a custom op. We need to pass the
    group name as a string, and then look up the group coordinator from
    the group name, dispatch the all-reduce operation to the group
    coordinator.

    In addition, PyTorch custom ops do not support mutation or returning
    a new tensor in the same op. So we need to figure out if the op is
    in-place or out-of-place ahead of time.

    Args:
        input_: Tensor to reduce across the group.

    Returns:
        The reduced tensor. This is ``input_`` itself (reduced in place)
        on every path except the HPU/XPU communicator paths and the custom
        all-reduce path, which return whatever those communicators produce.
    """
    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return input_

    if input_.is_cpu:
        try:
            import intel_extension_for_pytorch as ipex
        except ImportError:
            # IPEX is an optional dependency. Without it, fall back to the
            # native PyTorch collective instead of failing on the import.
            torch.distributed.all_reduce(input_, group=self.device_group)
        else:
            ipex.distributed.all_reduce(input_, group=self.device_group)
        return input_

    if not supports_custom_op():
        self._all_reduce_in_place(input_)
        return input_

    if self.hpu_communicator is not None and not self.hpu_communicator.disabled:
        return self.hpu_communicator.all_reduce(input_)

    if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
        return self.xpu_communicator.all_reduce(input_)

    if (
        self.ca_comm is not None
        and not self.ca_comm.disabled
        and self.ca_comm.should_custom_ar(input_)
    ):
        # Custom all-reduce produces a new tensor: out-of-place op.
        return torch.ops.sglang.outplace_all_reduce(
            input_, group_name=self.unique_name
        )
    else:
        torch.ops.sglang.inplace_all_reduce(input_, group_name=self.unique_name)
        return input_
400
+
401
+ def _all_reduce_out_place(self, input_: torch.Tensor) -> torch.Tensor:
402
+ ca_comm = self.ca_comm
403
+ assert ca_comm is not None
404
+ assert not ca_comm.disabled
405
+ out = ca_comm.custom_all_reduce(input_)
406
+ assert out is not None
407
+ return out
408
+
409
+ def _all_reduce_in_place(self, input_: torch.Tensor) -> None:
410
+ pynccl_comm = self.pynccl_comm
411
+ if pynccl_comm is not None and not pynccl_comm.disabled:
412
+ pynccl_comm.all_reduce(input_)
413
+ else:
414
+ torch.distributed.all_reduce(input_, group=self.device_group)
415
+
416
def all_gather(self, input_: torch.Tensor, dim: int = -1) -> torch.Tensor:
    """Gather ``input_`` from every rank and concatenate along ``dim``.

    Args:
        input_: Local tensor to contribute.
        dim: Dimension along which results are concatenated; may be negative.

    Returns:
        ``input_`` unchanged when the group has a single rank; otherwise a
        tensor whose size along ``dim`` is ``world_size`` times the input's.
    """
    group_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if group_size == 1:
        return input_
    assert (
        -input_.dim() <= dim < input_.dim()
    ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"

    # For HPUs, use HPU communicator.
    hpu_comm = self.hpu_communicator
    if hpu_comm is not None and not hpu_comm.disabled:
        return hpu_comm.all_gather(input_, dim)

    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    shape = input_.size()
    # NOTE: we have to use concat-style all-gather here,
    # stack-style all-gather has compatibility issues with
    # torch.compile . see https://github.com/pytorch/pytorch/issues/138795
    gathered = torch.empty(
        (shape[0] * group_size,) + shape[1:],
        dtype=input_.dtype,
        device=input_.device,
    )
    torch.distributed.all_gather_into_tensor(
        gathered, input_, group=self.device_group
    )
    # Expose the rank axis, move it next to `dim`, then fold it into `dim`.
    merged_shape = shape[:dim] + (group_size * shape[dim],) + shape[dim + 1 :]
    per_rank = gathered.reshape((group_size,) + shape)
    return per_rank.movedim(0, dim).reshape(merged_shape)
453
+
454
def gather(
    self, input_: torch.Tensor, dst: int = 0, dim: int = -1
) -> Optional[torch.Tensor]:
    """Gather ``input_`` from every rank onto the destination rank.

    NOTE: We assume that the input tensor is on the same device across
    all the ranks.
    NOTE: `dst` is the local rank of the destination rank.

    Returns:
        ``input_`` unchanged when the group has a single rank. Otherwise,
        on the destination rank, the per-rank tensors concatenated along
        ``dim``; on every other rank, ``None``. When an enabled XPU
        communicator exists, its ``gather`` result is returned instead.
    """
    group_size = self.world_size
    # Bypass the function if we are using only 1 GPU.
    if group_size == 1:
        return input_
    assert (
        -input_.dim() <= dim < input_.dim()
    ), f"Invalid dim ({dim}) for input tensor with shape {input_.size()}"
    if dim < 0:
        # Convert negative dim to positive.
        dim += input_.dim()
    if self.xpu_communicator is not None and not self.xpu_communicator.disabled:
        return self.xpu_communicator.gather(input_, self.rank_in_group, dst, dim)

    # Only the destination rank allocates receive buffers.
    is_destination = self.rank_in_group == dst
    gather_list = (
        [torch.empty_like(input_) for _ in range(group_size)]
        if is_destination
        else None
    )
    torch.distributed.gather(
        input_, gather_list, dst=self.ranks[dst], group=self.device_group
    )
    return torch.cat(gather_list, dim=dim) if is_destination else None
488
+
489
def broadcast(self, input_: torch.Tensor, src: int = 0):
    """Broadcast ``input_`` from ``src`` over the device group, in place.

    NOTE: `src` is the local rank of the source rank.

    Returns:
        ``input_`` (the same tensor object).
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Nothing to communicate if we are using only 1 GPU.
    if self.world_size != 1:
        torch.distributed.broadcast(
            input_, src=self.ranks[src], group=self.device_group
        )
    return input_
503
+
504
def broadcast_object(self, obj: Optional[Any] = None, src: int = 0):
    """Broadcast a Python object from ``src`` to the whole group.

    NOTE: `src` is the local rank of the source rank.

    The message-queue broadcaster is used when one exists (it only supports
    ``src == 0``); otherwise the object goes over the CPU group.

    Returns:
        ``obj`` on the source rank, the received object elsewhere.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Bypass the function if we are using only 1 GPU.
    if self.world_size == 1:
        return obj
    if self.mq_broadcaster is not None:
        assert src == 0, "Message queue broadcaster only supports src=0"
        return self.mq_broadcaster.broadcast_object(obj)

    is_sender = self.rank_in_group == src
    # Receivers pass a one-slot list that the broadcast fills in.
    payload = [obj] if is_sender else [None]
    torch.distributed.broadcast_object_list(
        payload, src=self.ranks[src], group=self.cpu_group
    )
    return obj if is_sender else payload[0]
527
+
528
def broadcast_object_list(
    self, obj_list: List[Any], src: int = 0, group: Optional[ProcessGroup] = None
):
    """Broadcast ``obj_list`` from ``src`` over the device group.

    NOTE: `src` is the local rank of the source rank.
    NOTE: `group` is accepted but not used; the broadcast always goes over
    ``self.device_group``.

    Returns:
        ``obj_list`` (the same list object).
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    # Nothing to communicate if we are using only 1 GPU.
    if self.world_size != 1:
        torch.distributed.broadcast_object_list(
            obj_list, src=self.ranks[src], group=self.device_group
        )
    return obj_list
544
+
545
def send_object(self, obj: Any, dst: int) -> None:
    """Pickle ``obj`` and send it to the destination rank over the CPU group.

    NOTE: `dst` is the local rank of the destination rank.

    Two messages are sent: first the byte length, then the pickled bytes.

    Raises:
        AssertionError: If ``dst`` is not below the world size or is the
            caller's own rank in the group.
    """
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert dst != self.rank_in_group, (
        "Invalid destination rank. Destination rank is the same "
        "as the current rank."
    )

    # Serialize the object into a byte tensor and record its length.
    payload = torch.frombuffer(pickle.dumps(obj), dtype=torch.uint8)
    length = torch.tensor([payload.numel()], dtype=torch.long, device="cpu")

    peer = self.ranks[dst]
    # The receiver needs the length first to size its buffer.
    torch.distributed.send(length, dst=peer, group=self.cpu_group)
    torch.distributed.send(payload, dst=peer, group=self.cpu_group)

    return None
571
+
572
def recv_object(self, src: int) -> Any:
    """Receive a pickled object from the source rank over the CPU group.

    NOTE: `src` is the local rank of the source rank.

    Two messages are received: first the byte length, then the payload.

    Raises:
        AssertionError: If ``src`` is not below the world size, is the
            caller's own rank in the group, or the two messages report
            different sender ranks.
    """
    assert src < self.world_size, f"Invalid src rank ({src})"

    assert (
        src != self.rank_in_group
    ), "Invalid source rank. Source rank is the same as the current rank."

    peer = self.ranks[src]

    # Receive the payload length first.
    length = torch.empty(1, dtype=torch.long, device="cpu")
    size_sender = torch.distributed.recv(length, src=peer, group=self.cpu_group)

    # Buffer sized to hold the serialized object.
    payload = torch.empty(  # type: ignore[call-overload]
        length.item(),  # type: ignore[arg-type]
        dtype=torch.uint8,
        device="cpu",
    )
    payload_sender = torch.distributed.recv(
        payload, src=peer, group=self.cpu_group
    )

    assert (
        payload_sender == size_sender
    ), "Received object sender rank does not match the size sender rank."

    # NOTE(review): unpickling executes arbitrary code from the payload, so
    # this assumes every peer in the group is trusted.
    return pickle.loads(payload.numpy().tobytes())
607
+
608
def broadcast_tensor_dict(
    self,
    tensor_dict: Optional[Dict[str, Union[torch.Tensor, Any]]] = None,
    src: int = 0,
    group: Optional[ProcessGroup] = None,
    metadata_group: Optional[ProcessGroup] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Broadcast a dictionary of tensors and plain values from ``src``.

    NOTE: `src` is the local rank of the source rank.
    NOTE: the `group` and `metadata_group` arguments are overwritten with
    ``self.device_group`` and ``self.cpu_group`` respectively.

    The metadata (keys, plain values, tensor descriptors) is broadcast as
    an object first; non-empty tensors then follow one by one.

    Returns:
        ``tensor_dict`` unchanged when torch.distributed is not initialized
        or the group has a single rank; the input dictionary on the source
        rank; a newly built dictionary on the other ranks.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    group = self.device_group
    metadata_group = self.cpu_group
    assert src < self.world_size, f"Invalid src rank ({src})"

    def _start_broadcast(tensor: torch.Tensor):
        # CPU tensors travel over the CPU group, all others over the
        # device group. Returns the async work handle.
        comm_group = metadata_group if tensor.is_cpu else group
        return torch.distributed.broadcast(
            tensor, src=self.ranks[src], group=comm_group, async_op=True
        )

    pending = []
    if self.rank_in_group == src:
        assert isinstance(
            tensor_dict, dict
        ), f"Expecting a dictionary, got {type(tensor_dict)}"
        metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
        # `metadata_list` lives in CPU memory.
        # `broadcast_object_list` has serialization & deserialization,
        # all happening on CPU. Therefore, we can use the CPU group.
        self.broadcast_object(metadata_list, src=src)
        # Empty tensors are skipped on both sides.
        pending = [
            _start_broadcast(tensor)
            for tensor in tensor_list
            if tensor.numel() != 0
        ]
    else:
        metadata_list = self.broadcast_object(None, src=src)
        tensor_dict = {}
        for key, value in metadata_list:
            if not isinstance(value, TensorMetadata):
                tensor_dict[key] = value
                continue
            tensor = torch.empty(
                value.size, dtype=value.dtype, device=value.device
            )
            if tensor.numel() != 0:
                pending.append(_start_broadcast(tensor))
            tensor_dict[key] = tensor

    # Wait for every outstanding broadcast before handing the dict back.
    for handle in pending:
        handle.wait()
    return tensor_dict
689
+
690
def send_tensor_dict(
    self,
    tensor_dict: Dict[str, Union[torch.Tensor, Any]],
    dst: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Send a dictionary of tensors and plain values to another rank.

    NOTE: `dst` is the local rank of the destination rank. When omitted,
    the next rank in the group (wrapping around) is used.

    Args:
        tensor_dict: Dictionary to send.
        dst: Local rank of the destination.
        all_gather_group: When given, a tensor whose element count is
            divisible by that group's world size is sent as a single slice
            (the one indexed by this process's rank in that group).

    Returns:
        ``tensor_dict`` unchanged when torch.distributed is not initialized
        or the group has a single rank; otherwise ``None``.
    """
    # Bypass the function if we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return tensor_dict

    if all_gather_group is None:
        all_gather_size, all_gather_rank = 1, 0
    else:
        all_gather_size = all_gather_group.world_size
        all_gather_rank = all_gather_group.rank_in_group

    group = self.device_group
    metadata_group = self.cpu_group

    if dst is None:
        dst = (self.rank_in_group + 1) % self.world_size
    assert dst < self.world_size, f"Invalid dst rank ({dst})"

    assert isinstance(
        tensor_dict, dict
    ), f"Expecting a dictionary, got {type(tensor_dict)}"
    metadata_list, tensor_list = _split_tensor_dict(tensor_dict)
    # `metadata_list` lives in CPU memory.
    # `send_object_list` has serialization & deserialization,
    # all happening on CPU. Therefore, we can use the CPU group.
    self.send_object(metadata_list, dst=dst)

    for tensor in tensor_list:
        if tensor.numel() == 0:
            # Skip sending empty tensors.
            continue

        # send-allgather: send only a slice, then do allgather.
        if all_gather_group is not None and tensor.numel() % all_gather_size == 0:
            tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]

        # CPU tensors use the CPU group, all others the device group.
        target_group = metadata_group if tensor.is_cpu else group
        torch.distributed.send(tensor, dst=self.ranks[dst], group=target_group)
    return None
742
+
743
def recv_tensor_dict(
    self,
    src: Optional[int] = None,
    all_gather_group: Optional["GroupCoordinator"] = None,
) -> Optional[Dict[str, Union[torch.Tensor, Any]]]:
    """Receive a tensor dictionary from another rank of this group.

    NOTE: `src` is the local rank (rank within this group) of the source
    rank; it defaults to the previous rank, wrapping around.

    When `all_gather_group` is given, tensors whose element count is
    divisible by that group's world size are received as one slice and then
    reassembled with an all-gather over that group.

    Returns `None` when torch.distributed is not initialized or the group
    has a single rank; otherwise the reconstructed dictionary.
    """
    # Nothing to receive when we are using only 1 GPU.
    if not torch.distributed.is_initialized() or self.world_size == 1:
        return None

    if all_gather_group is None:
        all_gather_size, all_gather_rank = 1, 0
    else:
        all_gather_size = all_gather_group.world_size
        all_gather_rank = all_gather_group.rank_in_group

    device_pg = self.device_group
    cpu_pg = self.cpu_group

    if src is None:
        src = (self.rank_in_group - 1) % self.world_size
    assert src < self.world_size, f"Invalid src rank ({src})"

    received: Dict[str, Any] = {}
    for key, value in self.recv_object(src=src):
        if not isinstance(value, TensorMetadata):
            # Plain (non-tensor) entries travel inside the metadata itself.
            received[key] = value
            continue

        buf = torch.empty(value.size, dtype=value.dtype, device=value.device)
        if buf.numel() == 0:
            # Empty tensors are never transmitted; the placeholder is final.
            received[key] = buf
            continue

        # send-allgather: receive only a slice, then do allgather.
        sliced = all_gather_group is not None and buf.numel() % all_gather_size == 0
        if sliced:
            full_shape = buf.shape
            buf = buf.reshape(all_gather_size, -1)[all_gather_rank]

        # CPU tensors go through the CPU group, everything else through
        # the device group.
        if buf.is_cpu:
            torch.distributed.recv(buf, src=self.ranks[src], group=cpu_pg)
        else:
            torch.distributed.recv(buf, src=self.ranks[src], group=device_pg)

        if sliced:
            buf = all_gather_group.all_gather(buf, dim=0)
            buf = buf.reshape(full_shape)

        received[key] = buf
    return received
804
+
805
def barrier(self):
    """Synchronize all ranks of this group on a barrier.

    NOTE: this intentionally goes through `cpu_group` rather than
    `device_group`. `barrier` in NCCL is terrible because it is internally
    a broadcast operation with secretly created GPU tensors, which makes it
    easy to mess up the current device.
    """
    cpu_pg = self.cpu_group
    torch.distributed.barrier(group=cpu_pg)
813
+
814
def send(self, tensor: torch.Tensor, dst: Optional[int] = None) -> None:
    """Send `tensor` to the destination rank.

    NOTE: `dst` is the local rank of the destination rank. When omitted,
    the next rank in the group (wrapping around) is used.

    The pynccl communicator is used when it exists and is not disabled;
    otherwise the tensor goes through `torch.distributed.send` on the
    device group.
    """
    target = (self.rank_in_group + 1) % self.world_size if dst is None else dst

    comm = self.pynccl_comm
    if comm is None or comm.disabled:
        torch.distributed.send(tensor, self.ranks[target], self.device_group)
        return
    comm.send(tensor, target)
825
+
826
def recv(
    self, size: torch.Size, dtype: torch.dtype, src: Optional[int] = None
) -> torch.Tensor:
    """Receive a tensor of the given `size` and `dtype` from the source rank.

    NOTE: `src` is the local rank of the source rank. When omitted, the
    previous rank in the group (wrapping around) is used.

    The destination buffer is allocated on `self.device` and returned once
    filled, via pynccl when available and enabled, else via
    `torch.distributed.recv` on the device group.
    """
    source = src if src is not None else (self.rank_in_group - 1) % self.world_size

    out = torch.empty(size, dtype=dtype, device=self.device)
    comm = self.pynccl_comm
    use_pynccl = comm is not None and not comm.disabled
    if use_pynccl:
        comm.recv(out, source)
    else:
        torch.distributed.recv(out, self.ranks[source], self.device_group)
    return out
841
+
842
def destroy(self):
    """Tear down this coordinator's process groups and drop its communicators.

    The device and CPU process groups are destroyed through
    `torch.distributed.destroy_process_group` (when set) and cleared; the
    pynccl, custom-allreduce and message-queue broadcaster references are
    then reset to `None`.
    """
    for group_attr in ("device_group", "cpu_group"):
        process_group = getattr(self, group_attr)
        if process_group is not None:
            torch.distributed.destroy_process_group(process_group)
            setattr(self, group_attr, None)

    for comm_attr in ("pynccl_comm", "ca_comm", "mq_broadcaster"):
        if getattr(self, comm_attr) is not None:
            setattr(self, comm_attr, None)
855
+
856
+
857
# Coordinator for the world group; remains None until
# init_distributed_environment() assigns it.
_WORLD: Optional[GroupCoordinator] = None


def get_world_group() -> GroupCoordinator:
    """Return the world group coordinator, asserting that it is initialized."""
    world = _WORLD
    assert world is not None, "world group is not initialized"
    return world
863
+
864
+
865
def init_world_group(
    ranks: List[int], local_rank: int, backend: str
) -> GroupCoordinator:
    """Create the coordinator for the world group.

    All of `ranks` form a single group named "world". Every optional
    communicator (pynccl, custom allreduce, HPU, XPU) is switched off.
    """
    world_options = dict(
        group_ranks=[ranks],
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=False,
        use_custom_allreduce=False,
        use_hpu_communicator=False,
        use_xpu_communicator=False,
        group_name="world",
    )
    return GroupCoordinator(**world_options)
878
+
879
+
880
def init_model_parallel_group(
    group_ranks: List[List[int]],
    local_rank: int,
    backend: str,
    use_custom_allreduce: Optional[bool] = None,
    use_message_queue_broadcaster: bool = False,
    group_name: Optional[str] = None,
) -> GroupCoordinator:
    """Create a coordinator for a model-parallel group.

    pynccl, the HPU communicator and the XPU communicator are always
    requested. Custom allreduce follows `use_custom_allreduce`, falling
    back to the module-level `_ENABLE_CUSTOM_ALL_REDUCE` switch when the
    argument is `None`.
    """
    custom_allreduce = (
        _ENABLE_CUSTOM_ALL_REDUCE
        if use_custom_allreduce is None
        else use_custom_allreduce
    )
    coordinator = GroupCoordinator(
        group_ranks=group_ranks,
        local_rank=local_rank,
        torch_distributed_backend=backend,
        use_pynccl=True,
        use_custom_allreduce=custom_allreduce,
        use_hpu_communicator=True,
        use_xpu_communicator=True,
        use_message_queue_broadcaster=use_message_queue_broadcaster,
        group_name=group_name,
    )
    return coordinator
901
+
902
+
903
# Coordinator for the tensor model-parallel group; None until
# initialize_model_parallel() assigns it.
_TP: Optional[GroupCoordinator] = None


def get_tp_group() -> GroupCoordinator:
    """Return the tensor model-parallel coordinator, asserting it exists."""
    tp_group = _TP
    assert tp_group is not None, "tensor model parallel group is not initialized"
    return tp_group


# Older name for the same accessor, kept for backward compatibility.
get_tensor_model_parallel_group = get_tp_group
913
+
914
# Coordinator for the pipeline model-parallel group; None until
# initialize_model_parallel() assigns it.
_PP: Optional[GroupCoordinator] = None


def get_pp_group() -> GroupCoordinator:
    """Return the pipeline model-parallel coordinator, asserting it exists."""
    pp_group = _PP
    assert pp_group is not None, "pipeline model parallel group is not initialized"
    return pp_group


# Older name for the same accessor, kept for backward compatibility.
get_pipeline_model_parallel_group = get_pp_group
924
+
925
+
926
@contextmanager
def graph_capture():
    """
    `graph_capture` is a context manager which should surround the code that
    is capturing the CUDA graph. Its main purpose is to ensure that some
    operations will be run after the graph is captured, before the graph
    is replayed. It returns a `GraphCaptureContext` object which contains the
    necessary data for the graph capture. Currently, it only contains the
    stream that the graph capture is running on. This stream is set to the
    current CUDA stream when the context manager is entered and reset to the
    default stream when the context manager is exited. This is to ensure that
    the graph capture is running on a separate stream from the default stream,
    in order to explicitly distinguish the kernels to capture
    from other kernels possibly launched on background in the default stream.

    The context is created by the tensor-parallel group's `graph_capture`
    and then handed to the pipeline-parallel group's `graph_capture`.
    """
    with get_tp_group().graph_capture() as context:
        with get_pp_group().graph_capture(context):
            yield context
945
+
946
+
947
logger = logging.getLogger(__name__)

# Module-wide default consulted by init_model_parallel_group() when its
# `use_custom_allreduce` argument is left as None.
_ENABLE_CUSTOM_ALL_REDUCE = True


def set_custom_all_reduce(enable: bool):
    """Set the module-wide default for using custom allreduce."""
    global _ENABLE_CUSTOM_ALL_REDUCE
    _ENABLE_CUSTOM_ALL_REDUCE = enable
955
+
956
+
957
def init_distributed_environment(
    world_size: int = -1,
    rank: int = -1,
    distributed_init_method: str = "env://",
    local_rank: int = -1,
    backend: str = "nccl",
):
    """Initialize torch.distributed (if needed) and the world group.

    If torch.distributed is not yet initialized, the default process group
    is created with the given backend, init method, world size and rank.
    The world group coordinator is then created once; on later calls its
    world size is only checked against torch.distributed's.
    """
    global _WORLD

    logger.debug(
        "world_size=%d rank=%d local_rank=%d distributed_init_method=%s backend=%s",
        world_size,
        rank,
        local_rank,
        distributed_init_method,
        backend,
    )

    if not torch.distributed.is_initialized():
        assert distributed_init_method is not None, (
            "distributed_init_method must be provided when initializing "
            "distributed environment"
        )
        # this backend is used for WORLD
        torch.distributed.init_process_group(
            backend=backend,
            init_method=distributed_init_method,
            world_size=world_size,
            rank=rank,
        )

    # set the local rank
    # local_rank is not available in torch ProcessGroup,
    # see https://github.com/pytorch/pytorch/issues/122816
    if local_rank == -1:
        # local rank not set, this usually happens in single-node
        # setting, where we can use rank as local rank
        local_rank = (
            int(os.environ.get("LOCAL_RANK", "0"))
            if distributed_init_method == "env://"
            else rank
        )

    if _WORLD is not None:
        assert (
            _WORLD.world_size == torch.distributed.get_world_size()
        ), "world group already initialized with a different world size"
        return

    all_ranks = list(range(torch.distributed.get_world_size()))
    _WORLD = init_world_group(all_ranks, local_rank, backend)
1002
+
1003
+
1004
def initialize_model_parallel(
    tensor_model_parallel_size: int = 1,
    pipeline_model_parallel_size: int = 1,
    backend: Optional[str] = None,
) -> None:
    """
    Initialize model parallel groups.

    Arguments:
        tensor_model_parallel_size: number of GPUs used for tensor model
            parallelism.
        pipeline_model_parallel_size: number of GPUs used for pipeline model
            parallelism.
        backend: torch.distributed backend for the new groups; defaults to
            the backend of the world group's device group.

    Let's say we have a total of 8 GPUs denoted by g0 ... g7 and we
    use 2 GPUs to parallelize the model tensor, and 4 GPUs to parallelize
    the model pipeline. The present function will
    create 4 tensor model-parallel groups and 2 pipeline model-parallel groups:
        4 tensor model-parallel groups:
            [g0, g1], [g2, g3], [g4, g5], [g6, g7]
        2 pipeline model-parallel groups:
            [g0, g2, g4, g6], [g1, g3, g5, g7]
    Note that for efficiency, the caller should make sure adjacent ranks
    are on the same DGX box. For example if we are using 2 DGX-1 boxes
    with a total of 16 GPUs, rank 0 to 7 belong to the first box and
    ranks 8 to 15 belong to the second box.

    Raises:
        RuntimeError: if the world size is not the product of the two sizes.
    """
    global _TP, _PP

    # Get world size and rank. Ensure some consistencies.
    assert torch.distributed.is_initialized()
    world_size: int = torch.distributed.get_world_size()
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if world_size != tensor_model_parallel_size * pipeline_model_parallel_size:
        raise RuntimeError(
            f"world_size ({world_size}) is not equal to "
            f"tensor_model_parallel_size ({tensor_model_parallel_size}) x "
            f"pipeline_model_parallel_size ({pipeline_model_parallel_size})"
        )

    # Tensor model-parallel groups: consecutive blocks of ranks.
    num_tensor_model_parallel_groups: int = world_size // tensor_model_parallel_size
    assert _TP is None, "tensor model parallel group is already initialized"
    tp_group_ranks = [
        list(
            range(
                block * tensor_model_parallel_size,
                (block + 1) * tensor_model_parallel_size,
            )
        )
        for block in range(num_tensor_model_parallel_groups)
    ]

    # message queue broadcaster is only used in tensor model parallel group
    _TP = init_model_parallel_group(
        tp_group_ranks,
        get_world_group().local_rank,
        backend,
        use_message_queue_broadcaster=True,
        group_name="tp",
    )

    # Pipeline model-parallel groups: ranks strided by the group count.
    num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
    assert _PP is None, "pipeline model parallel group is already initialized"
    pp_group_ranks = [
        list(range(offset, world_size, num_pipeline_model_parallel_groups))
        for offset in range(num_pipeline_model_parallel_groups)
    ]

    # pipeline parallel does not need custom allreduce
    _PP = init_model_parallel_group(
        pp_group_ranks,
        get_world_group().local_rank,
        backend,
        use_custom_allreduce=False,
        group_name="pp",
    )
1079
+
1080
+
1081
def ensure_model_parallel_initialized(
    tensor_model_parallel_size: int,
    pipeline_model_parallel_size: int,
    backend: Optional[str] = None,
) -> None:
    """Initialize the model parallel groups, or validate existing ones.

    If the groups do not exist yet they are created with the requested
    sizes. Otherwise the existing tensor-parallel and pipeline-parallel
    world sizes are asserted to match the requested values.
    """
    backend = backend or torch.distributed.get_backend(get_world_group().device_group)

    if model_parallel_is_initialized():
        assert get_tensor_model_parallel_world_size() == tensor_model_parallel_size, (
            "tensor parallel group already initialized, but of unexpected size: "
            f"{get_tensor_model_parallel_world_size()=} vs. "
            f"{tensor_model_parallel_size=}"
        )
        pp_world_size = get_pp_group().world_size
        assert pp_world_size == pipeline_model_parallel_size, (
            "pipeline parallel group already initialized, but of unexpected size: "
            f"{pp_world_size=} vs. "
            f"{pipeline_model_parallel_size=}"
        )
        return

    initialize_model_parallel(
        tensor_model_parallel_size, pipeline_model_parallel_size, backend
    )
1108
+
1109
+
1110
def model_parallel_is_initialized():
    """Return True only when both the tensor and pipeline groups are set."""
    return not (_TP is None or _PP is None)
1113
+
1114
+
1115
# True while patch_tensor_parallel_group() has the TP group swapped out.
_TP_STATE_PATCHED = False


@contextmanager
def patch_tensor_parallel_group(tp_group: GroupCoordinator):
    """Patch the tp group temporarily until this function ends.

    This method is for draft workers of speculative decoding to run draft model
    with different tp degree from that of target model workers.

    Args:
        tp_group (GroupCoordinator): the tp group coordinator

    Raises:
        AssertionError: if a patch is already active, or (from
            `get_tp_group`) if no tensor-parallel group is initialized.
    """
    global _TP, _TP_STATE_PATCHED
    assert not _TP_STATE_PATCHED, "Should not call when it's already patched"

    # Fetch the current group BEFORE flipping the flag: get_tp_group()
    # asserts when no TP group exists, and raising after the flag was set
    # (but outside the try/finally) would leave it stuck at True, making
    # every later call fail as "already patched".
    old_tp_group = get_tp_group()

    _TP_STATE_PATCHED = True
    _TP = tp_group
    try:
        yield
    finally:
        # restore the original state
        _TP_STATE_PATCHED = False
        _TP = old_tp_group
1141
+
1142
+
1143
def get_tensor_model_parallel_world_size():
    """Return the number of ranks in the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.world_size
1146
+
1147
+
1148
def get_tensor_model_parallel_rank():
    """Return this process's rank within the tensor model parallel group."""
    tp_group = get_tp_group()
    return tp_group.rank_in_group
1151
+
1152
+
1153
def destroy_model_parallel():
    """Destroy the tensor and pipeline parallel groups and reset them to None."""
    global _TP, _PP

    # Tensor-parallel group first, then pipeline-parallel.
    if _TP:
        _TP.destroy()
    _TP = None

    if _PP:
        _PP.destroy()
    _PP = None
1164
+
1165
+
1166
def destroy_distributed_environment():
    """Destroy the world group, then the default torch.distributed group."""
    global _WORLD

    if _WORLD:
        _WORLD.destroy()
    _WORLD = None

    still_initialized = torch.distributed.is_initialized()
    if still_initialized:
        torch.distributed.destroy_process_group()
1173
+
1174
+
1175
def cleanup_dist_env_and_memory(shutdown_ray: bool = False):
    """Tear down all parallel state and release cached memory.

    Destroys the model parallel groups and the distributed environment,
    optionally shuts Ray down, runs the garbage collector and, on non-CPU
    platforms, empties the CUDA caching allocator.
    """
    destroy_model_parallel()
    destroy_distributed_environment()

    # Best-effort second teardown of the default process group; an
    # AssertionError from it is deliberately ignored.
    try:
        torch.distributed.destroy_process_group()
    except AssertionError:
        pass

    if shutdown_ray:
        import ray  # Lazy import Ray

        ray.shutdown()

    gc.collect()
    on_cpu = current_platform.is_cpu()
    if not on_cpu:
        torch.cuda.empty_cache()
1187
+
1188
+
1189
def in_the_same_node_as(pg: ProcessGroup, source_rank: int = 0) -> List[bool]:
    """
    This is a collective operation that returns if each rank is in the same node
    as the source rank. It tests if processes are attached to the same
    memory system (shared access to shared memory).

    Args:
        pg: the process group to test; must not use the NCCL backend.
        source_rank: rank *within `pg`* that creates the shared memory
            segment every other rank tries to open.

    Returns:
        One boolean per rank of `pg`; entry `i` is True when rank `i`
        could read the source rank's shared memory segment (the source
        rank marks itself after a successful broadcast).
    """
    assert (
        torch.distributed.get_backend(pg) != torch.distributed.Backend.NCCL
    ), "in_the_same_node_as should be tested with a non-NCCL group."
    # local rank inside the group
    rank = torch.distributed.get_rank(group=pg)
    world_size = torch.distributed.get_world_size(group=pg)

    # local tensor in each process to store the result
    # (each rank only ever sets its own slot; the all_reduce at the end
    # combines the slots of all ranks)
    is_in_the_same_node = torch.tensor([0] * world_size, dtype=torch.int32)

    # global ranks of the processes in the group
    ranks = torch.distributed.get_process_group_ranks(pg)

    magic_message = b"magic_message"
    shm = None

    try:
        # OSError (e.g. the segment cannot be created or opened) is
        # swallowed here; the rank's slot then simply stays 0.
        # NOTE(review): if the source rank hits OSError while creating the
        # segment it leaves this block before broadcasting, while the
        # other ranks still enter the broadcast — confirm this cannot hang.
        with contextlib.suppress(OSError):
            if rank == source_rank:
                # create a shared memory segment
                shm = shared_memory.SharedMemory(create=True, size=128)
                shm.buf[: len(magic_message)] = magic_message
                # publish the segment name to every other rank
                torch.distributed.broadcast_object_list(
                    [shm.name], src=ranks[source_rank], group=pg
                )
                is_in_the_same_node[rank] = 1
            else:
                # try to open the shared memory segment
                recv = [None]
                torch.distributed.broadcast_object_list(
                    recv, src=ranks[source_rank], group=pg
                )
                name = recv[0]
                # fix to https://stackoverflow.com/q/62748654/9191338
                # Python incorrectly tracks shared memory even if it is not
                # created by the process. The following patch is a workaround.
                with patch(
                    "multiprocessing.resource_tracker.register",
                    lambda *args, **kwargs: None,
                ):
                    shm = shared_memory.SharedMemory(name=name)
                # only count as "same node" if the expected payload is visible
                if shm.buf[: len(magic_message)] == magic_message:
                    is_in_the_same_node[rank] = 1
    except Exception as e:
        # best-effort probe: any other failure is logged, not raised
        logger.error("Error ignored in is_in_the_same_node: %s", e)
    finally:
        if shm:
            shm.close()

    # wait until every rank is done with the segment before unlinking it
    torch.distributed.barrier(group=pg)

    # clean up the shared memory segment
    with contextlib.suppress(OSError):
        if rank == source_rank and shm:
            shm.unlink()
    # default all_reduce merges the per-rank slots into one shared vector
    torch.distributed.all_reduce(is_in_the_same_node, group=pg)

    return [x == 1 for x in is_in_the_same_node.tolist()]
1253
+
1254
+
1255
# Original vLLM accessors, captured the first time
# monkey_patch_vllm_parallel_state() runs so they can be restored later.
vllm_get_pp_group = None
vllm_get_tp_group = None
vllm_get_world_group = None


def monkey_patch_vllm_parallel_state(reverse: bool = False):
    """Swap vLLM's group accessors for this module's, or restore them.

    On the first call the current `get_pp_group`, `get_tp_group` and
    `get_world_group` of `vllm.distributed.parallel_state` are saved.
    With `reverse=False` those attributes are replaced by this module's
    functions; with `reverse=True` the saved ones are put back.
    """
    import vllm.distributed.parallel_state as vllm_parallel_state

    global vllm_get_pp_group, vllm_get_tp_group, vllm_get_world_group
    if vllm_get_pp_group is None:
        vllm_get_pp_group = vllm_parallel_state.get_pp_group
        vllm_get_tp_group = vllm_parallel_state.get_tp_group
        vllm_get_world_group = vllm_parallel_state.get_world_group

    if reverse:
        replacements = {
            "get_pp_group": vllm_get_pp_group,
            "get_tp_group": vllm_get_tp_group,
            "get_world_group": vllm_get_world_group,
        }
    else:
        replacements = {
            "get_pp_group": get_pp_group,
            "get_tp_group": get_tp_group,
            "get_world_group": get_world_group,
        }

    for attr_name, accessor in replacements.items():
        setattr(vllm_parallel_state, attr_name, accessor)