torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/process_group.py
ADDED
@@ -0,0 +1,2118 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Process Groups
=========================

This module implements fault tolerant process groups that can be reconfigured
and resized at runtime.

These extend the standard PyTorch ProcessGroup API and can be used in most
places that would accept a standard process group. As these can change size at
runtime users need to take care to not assume a static rank or world size.
"""

import logging
import os
import threading
import time
import warnings
from contextlib import contextmanager, nullcontext
from dataclasses import dataclass
from datetime import timedelta
from multiprocessing.connection import Connection
from typing import (
    Any,
    Callable,
    cast,
    Dict,
    Generator,
    List,
    Optional,
    Tuple,
    TYPE_CHECKING,
    TypeVar,
    Union,
)

import torch
import torch.distributed as dist
import torch.multiprocessing as mp

# pyre-fixme[21]: no attribute ProcessGroupGloo
from torch.distributed import (
    PrefixStore,
    ProcessGroup as BaseProcessGroup,
    ProcessGroupGloo as BaseProcessGroupGloo,
    Store,
    TCPStore,
)
from torch.distributed.distributed_c10d import (
    AllgatherOptions,
    AllreduceCoalescedOptions,
    AllreduceOptions,
    AllToAllOptions,
    BarrierOptions,
    BroadcastOptions,
    ReduceOp,
    ReduceScatterOptions,
    Work,
)
from torch.futures import Future
from torch.utils._pytree import tree_any

# We import these for backwards compatibility
from torchft.futures import context_timeout, stream_timeout
from torchft.multiprocessing import _MonitoredPipe
from torchft.utils import get_stream_context, record_event, synchronize
from torchft.work import _DummyWork

if TYPE_CHECKING:
    from torchft.manager import Manager

logger: logging.Logger = logging.getLogger(__name__)

# TODO: use non strings which are cheaper
_QUEUE_CLOSE = "queue_close"
_FUTURE_RESULT = "fut_result"
_FUTURE_EXCEPTION = "fut_exception"


T = TypeVar("T")

TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR = "TORCH_NCCL_DEBUG_INFO_PIPE_FILE"
# Used to trigger flight recorder if we trigger abort on the process group
TORCHFT_TRIGGER_FR_ON_ABORT = "TORCHFT_TRIGGER_FR_ON_ABORT"


def trigger_nccl_fr_trace_through_pipe(rank: int) -> bool:
    """Collect NCCL flight recorder trace through the pipe."""
    dump_file_prefix = os.environ.get(TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR, "")
    if not dump_file_prefix:
        logging.info(
            f"[rank {rank}] Triggering FR trace dump through pipe failed: pipe is not enabled."
        )
        return False
    pipe_name = f"{dump_file_prefix}{rank}.pipe"
    with open(pipe_name, "w") as f:
        # Trigger fr trace dump through pipe
        logging.info(f"[rank {rank}] Triggering FR trace dump through pipe...")
        f.write("1\n")
        time.sleep(60)
    return True
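

# Illustrative sketch: opting in to the flight recorder hook above. The pipe
# path prefix is a placeholder; the two environment variables are the ones
# defined in this module, and abort() only triggers the dump when the first
# one is set to "true".
def _example_enable_fr_on_abort() -> None:
    os.environ[TORCHFT_TRIGGER_FR_ON_ABORT] = "true"
    # flight recorder traces are requested through "<prefix><rank>.pipe"
    os.environ[TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR] = "/tmp/nccl_fr_"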


def create_store_client(store_addr: str, timeout: timedelta) -> Store:
    """
    Creates a PrefixStore(TCPStore(...)) client from an address in the format:

    host:port/prefix

    Ex: localhost:1234/my/prefix
    """
    host, _, rest = store_addr.partition(":")
    port, _, prefix = rest.partition("/")

    store = TCPStore(
        host_name=host,
        port=int(port),
        is_master=False,
        wait_for_workers=False,
        timeout=timeout,
    )
    store = PrefixStore(prefix, store)
    return store
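

# Illustrative sketch of the overall usage pattern: build one reconfigurable
# group and re-configure() it for every new quorum, always with a fresh
# prefixed store address in the same "host:port/prefix" format parsed above.
# The address, replica id and world size here are placeholders.
def _example_reconfigure_and_allreduce(
    rank: int, world_size: int, quorum_id: int
) -> None:
    pg = ProcessGroupGloo(timeout=timedelta(seconds=30))
    pg.configure(
        store_addr=f"localhost:29500/torchft/quorum_{quorum_id}",
        replica_id="replica_0",
        rank=rank,
        world_size=world_size,
    )
    t = torch.ones(8)
    opts = AllreduceOptions()
    opts.reduceOp = ReduceOp.SUM
    pg.allreduce([t], opts).wait()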


class ProcessGroup(BaseProcessGroup):
    def __init__(self, *args: object, **kwargs: object) -> None:
        # pyre-fixme[6]: got object
        super().__init__(*args, **kwargs)

        self._group_name: Optional[str] = None

    # pyre-fixme[14]: inconsistent override
    def allgather(
        self,
        output_tensors: List[List[torch.Tensor]],
        input_tensor: List[torch.Tensor],
        opts: AllgatherOptions,
    ) -> Work:
        """
        Gathers tensors from the whole group in a list.

        See torch.distributed.all_gather for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def allgather_into_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: AllgatherOptions,
    ) -> Work:
        """
        Performs an allgather operation on coalesced tensors.

        See torch.distributed.allgather_coalesced for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def allreduce(
        self,
        tensors: List[torch.Tensor],
        opts: Union[AllreduceOptions, ReduceOp],
    ) -> Work:
        """
        Reduces the tensor data across all machines in such a way that all get the final result.

        See torch.distributed.all_reduce for more details.
        """
        raise NotImplementedError("not implemented")

    def allreduce_coalesced(
        self,
        tensors: List[torch.Tensor],
        opts: AllreduceCoalescedOptions,
    ) -> Work:
        """
        Performs an all_reduce operation in a coalesced manner.

        See torch.distributed.all_reduce_coalesced for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts: AllToAllOptions,
    ) -> Work:
        """
        Performs an all_to_all operation.

        See torch.distributed.all_to_all_single for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def barrier(self, opts: BarrierOptions) -> Work:
        """
        Synchronizes all processes.

        See torch.distributed.barrier for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def broadcast(
        self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
    ) -> Work:
        """
        Broadcasts the tensor to the whole group.

        See torch.distributed.broadcast for more details.
        """
        raise NotImplementedError("not implemented")

    def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
        opts = BroadcastOptions()
        opts.rootRank = root
        return self.broadcast([tensor], opts)

    # pyre-fixme[14]: inconsistent override
    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
        """
        Receives a list of tensors from the process with rank `rank`.

        See torch.distributed.recv for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def reduce_scatter(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[List[torch.Tensor]],
        opts: ReduceScatterOptions,
    ) -> Work:
        """
        Reduces, then scatters a list of tensors to all processes in a group.

        See torch.distributed.reduce_scatter for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def reduce_scatter_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: ReduceScatterOptions,
    ) -> Work:
        """
        Performs a reduce-scatter operation on coalesced tensors.

        See torch.distributed.reduce_scatter_tensor for more details.
        """
        raise NotImplementedError("not implemented")

    # pyre-fixme[14]: inconsistent override
    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
        """
        Sends a list of tensors to the process with rank `dst_rank`.

        See torch.distributed.send for more details.
        """
        raise NotImplementedError("not implemented")

    def configure(
        self,
        store_addr: str,
        replica_id: str,
        rank: int,
        world_size: int,
        quorum_id: Optional[int] = None,
        group_rank: Optional[int] = None,
        group_world_size: Optional[int] = None,
        global_ranks: Optional[list[int]] = None,
    ) -> None:
        """
        This reconfigures the ProcessGroup to use a new store, rank and world size.

        Every time this is called it must be provided with a unique prefixed
        store address. I.e. localhost:1234/my/prefix/1

        This function will block until the underlying ProcessGroup is created.
        If an error occurs this will throw.

        Args:
            store_addr: address of the store to use
            replica_id: the replica_id for this group
            rank: rank of this process
            world_size: world size of this process group
            quorum_id: current quorum's identifier
            group_rank: local rank within the replica group
            group_world_size: the number of ranks within a replica
            global_ranks: the global ranks part of this process group
        """
        raise NotImplementedError("not implemented")

    def size(self) -> int:
        raise NotImplementedError("not implemented")

    def getBackendName(self) -> str:
        raise NotImplementedError("not implemented")

    def _register(self, name: str) -> str:
        group_name = f"{self.getBackendName()}:{name}"

        # This is needed for DeviceMesh and functional collectives to work.
        # Resizable worlds don't fit well into DeviceMesh so we register a world
        # size 1 PG.

        def create_pg(
            prefix_store: PrefixStore, rank: int, world_size: int, timeout: float
        ) -> ProcessGroup:
            return self

        devices = ["cpu"]
        if torch.cuda.is_available():
            devices.append("cuda")
        elif torch.xpu.is_available():
            devices.append("xpu")
        dist.Backend.register_backend(group_name, create_pg, devices=devices)

        return group_name

    def register(self, name: str) -> "ProcessGroup":
        """
        Registers the process group with the global registry. This enables usage
        with things like functional_collectives which are compilable.

        This should only be called once.

        Args:
            name: name must be a unique name for this process group
        """

        group_name = self._register(name)

        return dist.new_group(
            ranks=[dist.get_rank()],
            backend=group_name,
            group_desc=group_name,
            timeout=timedelta(seconds=60.0),  # this timeout isn't used
        )

    @property
    def group_name(self) -> str:
        if self._group_name is None:
            raise ValueError("ProcessGroup name not set")
        return self._group_name

    def _set_group_name(self, name: str) -> None:
        self._group_name = name

    def unregister(self) -> None:
        """
        Unregisters the process group with the global registry.

        Must be registered first.
        """
        dist.destroy_process_group(self)

    def abort(self) -> None:
        """
        Aborts the process group.
        """
        pass

    def shutdown(self) -> None:
        """
        Shuts down the process group.
        """
        pass

    def errored(self) -> Optional[Exception]:
        """
        Whether an async error occurred that requires reconfiguration.
        """
        return None

    def set_timeout(self, timeout: timedelta) -> None:
        """
        Sets the default timeout for the process group.
        """
        raise NotImplementedError("set_timeout not implemented")

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}()"


class ProcessGroupWrapper(ProcessGroup):
    """
    This is a wrapper around any ProcessGroup with a reconfiguration method.

    Args:
        timeout: timeout for reconfiguration for TCPStore
        pg: optional ProcessGroup to use, if None a new one will be created
    """

    def __init__(
        self,
        timeout: timedelta = timedelta(seconds=60),
        pg: Optional[ProcessGroup] = None,
    ) -> None:
        super().__init__(0, 1)
        self._pg: Optional[BaseProcessGroup] = pg
        self._timeout = timeout
        self._replica_id: str | None = None
        self._rank: int | None = None
        self._quorum_id: int | None = None
        self._group_rank: int | None = None
        self._group_world_size: int | None = None
        self._global_ranks: list[int] | None = None

        self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

    def getBackendName(self) -> str:
        pg = self._pg
        if isinstance(pg, ProcessGroup):
            return pg.getBackendName()

        raise NotImplementedError("not implemented")

    def configure(
        self,
        store_addr: str,
        replica_id: str,
        rank: int,
        world_size: int,
        quorum_id: Optional[int] = None,
        group_rank: Optional[int] = None,
        group_world_size: Optional[int] = None,
        global_ranks: Optional[list[int]] = None,
    ) -> None:
        pg = self._pg
        self._replica_id = replica_id
        self._quorum_id = quorum_id
        self._group_rank = group_rank
        self._group_world_size = group_world_size
        self._rank = rank
        self._global_ranks = global_ranks
        if isinstance(pg, ProcessGroup):
            pg.configure(
                store_addr,
                replica_id,
                rank,
                world_size,
                quorum_id,
                group_rank,
                group_world_size,
                global_ranks,
            )
            return

        # abort if already initialized
        self.abort(errored=False)

        store = create_store_client(store_addr, timeout=self._timeout)

        self._pg = self._create_pg(store, rank, world_size)

    def abort(self, errored: bool = True) -> None:
        if errored:
            self.errors_logger.info(
                "",
                extra={
                    "job_id": os.environ.get("JOB_ID", "unknown"),
                    "replica_id": self._replica_id,
                    "rank": self._rank,
                    "quorum_id": self._quorum_id,
                    "error": "process_group_abort",
                },
            )
        pg = self._pg
        if pg is not None:
            if hasattr(pg, "abort"):
                pg.abort()
            else:
                backend = None
                try:
                    if torch.cuda.is_available():
                        backend = pg._get_backend(torch.device("cuda"))
                    elif torch.xpu.is_available():
                        backend = pg._get_backend(torch.device("xpu"))
                except RuntimeError:
                    backend = None
                if backend is not None and hasattr(backend, "abort"):
                    backend.abort()

        self._pg = None

    def shutdown(self) -> None:
        # TODO: abort PG if possible
        self._pg = None

    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        raise NotImplementedError("not implemented")

    def _wrap_work(self, work: Work, opts: object) -> Work:
        return work

    def _opts_hook(self, opts: T) -> T:
        return opts

    @contextmanager
    def _run_context(self) -> Generator[None, None, None]:
        yield

    def set_timeout(self, timeout: timedelta) -> None:
        self._timeout = timeout

    def allgather(
        self,
        output_tensors: List[List[torch.Tensor]],
        input_tensor: List[torch.Tensor],
        opts: AllgatherOptions,
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.allgather(
                    output_tensors, input_tensor, self._opts_hook(opts)
                ),
                opts,
            )

    def allgather_into_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: AllgatherOptions,
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.allgather_into_tensor_coalesced(
                    output_tensors, input_tensors, self._opts_hook(opts)
                ),
                opts,
            )

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.allreduce(tensors, self._opts_hook(opts)), opts
            )

    def allreduce_coalesced(
        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.allreduce_coalesced(tensors, self._opts_hook(opts)), opts
            )

    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts: AllToAllOptions,
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.alltoall_base(
                    output_buffer,
                    input_buffer,
                    output_split_sizes,
                    input_split_sizes,
                    self._opts_hook(opts),
                ),
                opts,
            )

    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
        with self._run_context():
            return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)

    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.broadcast(tensor_list, self._opts_hook(opts)), opts
            )

    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
        with self._run_context():
            return self._wrap_work(self.parent.recv(tensors, src_rank, tag), None)

    def reduce_scatter(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[List[torch.Tensor]],
        opts: object,
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.reduce_scatter(
                    output_tensors, input_tensors, self._opts_hook(opts)
                ),
                opts,
            )

    def reduce_scatter_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: ReduceScatterOptions,
    ) -> Work:
        with self._run_context():
            return self._wrap_work(
                self.parent.reduce_scatter_tensor_coalesced(
                    output_tensors, input_tensors, self._opts_hook(opts)
                ),
                opts,
            )

    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
        with self._run_context():
            return self._wrap_work(self.parent.send(tensors, dst_rank, tag), None)

    def size(self) -> int:
        return self.parent.size()

    @property
    def parent(self) -> BaseProcessGroup:
        assert self._pg is not None, "process group not initialized"
        return self._pg

    def __repr__(self) -> str:
        return f"{self.__class__.__name__}(pg={self._pg})"


class ProcessGroupGloo(ProcessGroupWrapper):
    """
    This is a reconfigurable version of ProcessGroupGloo.
    """

    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        pg = BaseProcessGroup(store, rank, world_size)
        pg._set_default_backend(ProcessGroup.BackendType.GLOO)
        # pyre-fixme[16]: no attribute ProcessGroupGloo
        backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
        backend_class._set_sequence_number_for_group()

        if self._global_ranks:
            backend_class.options.global_ranks_in_group = self._global_ranks
        if self._group_rank and self._group_world_size:
            backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

        pg._register_backend(
            torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
        )
        if torch.cuda.is_available():
            pg._register_backend(
                torch.device("cuda"), ProcessGroup.BackendType.GLOO, backend_class
            )
        return pg

    def getBackendName(self) -> str:
        return "torchft-gloo"

    # pyre-fixme[14,15]: inconsistent override
    def reduce_scatter(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[List[torch.Tensor]],
        opts: ReduceScatterOptions,
    ) -> None:
        """
        This function is a placeholder for the reduce_scatter operation in the
        ProcessGroupGloo class. However, this operation is not supported by the
        Gloo backend, and thus, calling this function will raise a
        RuntimeError.

        Raises:
            RuntimeError: Always raised since reduce_scatter is not
                supported by ProcessGroupGloo.
        """
        raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")

    # pyre-fixme[15]: inconsistent override
    def reduce_scatter_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: ReduceScatterOptions,
    ) -> None:
        """
        This function is a placeholder for the reduce_scatter_tensor_coalesced
        operation in the ProcessGroupGloo class.
        However, this operation is not supported by the
        Gloo backend, and thus, calling this function will raise a
        RuntimeError.

        Raises:
            RuntimeError: Always raised since reduce_scatter is not
                supported by ProcessGroupGloo.
        """
        raise RuntimeError(
            "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
        )


class _WorkAcceleratorTimeout(Work):
    def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
        super().__init__()
        self._pg = pg
        self._work = work
        self._timeout = timeout

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        async_timeout = timeout or self._timeout
        with self._stream_timeout(self._pg, async_timeout):
            # In newer versions of PyTorch work may not exist if the call was
            # not async. In these cases we can just schedule the stream timeout
            # and return.
            if self._work is not None:
                if not self._work.wait():
                    return False

        # Always use cuda stream for timeout to avoid ProcessGroupNCCL
        # watchdog firing and crashing the process.
        if timeout is not None:
            torch.cuda.synchronize()

        return True

    @classmethod
    @contextmanager
    def _stream_timeout(
        cls, pg: ProcessGroup, timeout: timedelta
    ) -> Generator[None, None, None]:
        """
        Set a timeout on the CUDA stream for the given process group.

        This does not hold a reference to self to avoid holding the work
        object/tensors longer than necessary.

        Args:
            pg: The process group to call abort on.
            timeout: The timeout to set on the CUDA stream.
        """

        def callback() -> None:
            logger.error(f"aborting after {timeout}!")
            pg.abort()

        # make sure .wait() can be cancelled if it blocks i.e. in barrier
        with context_timeout(callback, timeout):
            yield

        # Cancel work if the cuda stream doesn't complete
        stream_timeout(callback, timeout)

    def get_future(self) -> torch.futures.Future[object]:
        fut = self._work.get_future()

        def done_callback(fut: torch.futures.Future[object]) -> None:
            try:
                with self._stream_timeout(self._pg, self._timeout):
                    fut.wait()

            except Exception as e:
                logger.error(f"done callback failed: {e}")

        fut.add_done_callback(done_callback)
        return fut


class ProcessGroupNCCL(ProcessGroupWrapper):
    """
    This is a reconfigurable version of ProcessGroupNCCL.

    If you are using a supported version of NCCL (NCCL >= 2.26, torch >= 2.7)
    this will attempt to use ncclCommAbort to recover from any timeouts.

    This uses a Python user space event loop to asynchronously wait for the NCCL
    operations to complete. This should not be used with very long timeouts as
    the timeout entries are not cleaned up until the elapsed duration completes
    which may result in slowness or excess memory usage.

    WARNING: this may result in deadlocks due to NCCL error handling and on old
    versions of torch/NCCL will result in deadlocks.

    Args:
        timeout: the timeout to use for NCCL operations.
    """

    def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
        super().__init__(timeout)
        self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)

        self._errored: Optional[Exception] = None

        NONBLOCKING_TIMEOUT_ENV = "TORCH_NCCL_NONBLOCKING_TIMEOUT"
        if NONBLOCKING_TIMEOUT_ENV not in os.environ:
            warnings.warn(
                f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
                "If any nonblocking NCCL operations have already run this may "
                "result in the default timeout of 30 minutes and hangs on error."
            )
            os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())

    def _opts_hook(self, opts: T) -> T:
        if not self._use_abort:
            return opts

        # We need to clear the timeout to apply our own timeout that doesn't
        # crash the whole program.
        if hasattr(opts, "timeout"):
            # apply default timeout to disable
            opts.timeout = AllgatherOptions().timeout
        return opts

    def _wrap_work(self, work: Work, opts: object) -> Work:
        if not self._use_abort:
            return work

        timeout = self._timeout
        # pyre-fixme[16]: no attribute timeout
        if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0:
            timeout = opts.timeout
        return _WorkAcceleratorTimeout(self, work, timeout)

    @contextmanager
    def _run_context(self) -> Generator[None, None, None]:
        timeout: timedelta = self._timeout

        def callback() -> None:
            logger.error(f"aborting after {timeout}!")
            self.abort()

        # when running in blocking mode we need to make sure collectives can
        # timeout
        with context_timeout(callback, timeout):
            yield

    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[21]: no attribute ProcessGroupNCCL
        from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL

        self._errored = None

        # pyre-fixme[16]: no attribute ProcessGroupNCCL
        opts = BaseProcessGroupNCCL.Options()
        opts.config.blocking = False
        if self._global_ranks:
            opts.global_ranks_in_group = self._global_ranks
        if self._group_rank and self._group_world_size:
            opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"

        pg = BaseProcessGroup(store, rank, world_size)
        pg._set_default_backend(ProcessGroup.BackendType.NCCL)
        # pyre-fixme[16]: no attribute ProcessGroupNCCL
        backend_class = BaseProcessGroupNCCL(store, rank, world_size, opts)
        backend_class._set_sequence_number_for_group()
        backend_class.eager_connect_single_device(
            torch.device(torch.accelerator.current_device_index())
        )
        pg._register_backend(
            torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
        )
        return pg

    def abort(self, errored: bool = True) -> None:
        # We need to set the error before aborting to ensure that errored()
        # returns the error correctly when NCCL abort fires and unblocks the
        # stream.
        if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "false") == "true":
            trigger_nccl_fr_trace_through_pipe(dist.get_rank())
        self._errored = RuntimeError("aborted")

        super().abort(errored=errored)

    def errored(self) -> Optional[Exception]:
        # force a synchronization to ensure all work is complete
        synchronize()
        return self._errored

    def getBackendName(self) -> str:
        return "torchft-nccl"


class ProcessGroupXCCL(ProcessGroupWrapper):
    """
    This is a reconfigurable version of ProcessGroupXCCL for Intel XPU devices.

    This process group is designed to work with Intel XPU devices using XCCL
    (eXtended Collective Communication Library). It provides similar functionality
    to ProcessGroupNCCL but optimized for Intel XPU architecture.

    If you are using a supported version of XCCL, this will attempt to use
    xccl abort mechanisms to recover from any timeouts.

    This uses a Python user space event loop to asynchronously wait for the XCCL
    operations to complete. This should not be used with very long timeouts as
    the timeout entries are not cleaned up until the elapsed duration completes
    which may result in slowness or excess memory usage.

    Args:
        timeout: the timeout to use for XCCL operations.
    """

    def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
        super().__init__(timeout)
        # Check if XPU is available and XCCL is supported
        self._use_abort: bool = torch.xpu.is_available()

        self._errored: Optional[Exception] = None

        NONBLOCKING_TIMEOUT_ENV = "TORCH_XCCL_NONBLOCKING_TIMEOUT"
        if NONBLOCKING_TIMEOUT_ENV not in os.environ:
            warnings.warn(
                f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
                "If any nonblocking XCCL operations have already run this may "
                "result in the default timeout of 30 minutes and hangs on error."
            )
            os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())

    def _opts_hook(self, opts: T) -> T:
        if not self._use_abort:
            return opts

        # We need to clear the timeout to apply our own timeout that doesn't
        # crash the whole program.
        if hasattr(opts, "timeout"):
            # apply default timeout to disable
            opts.timeout = AllgatherOptions().timeout
        return opts

    def _wrap_work(self, work: Work, opts: object) -> Work:
        if not self._use_abort:
            return work

        timeout = self._timeout
        # pyre-fixme[16]: no attribute timeout
        if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0:
            timeout = opts.timeout
        return _WorkAcceleratorTimeout(self, work, timeout)

    @contextmanager
    def _run_context(self) -> Generator[None, None, None]:
        timeout: timedelta = self._timeout

        def callback() -> None:
            logger.error(f"aborting after {timeout}!")
            self.abort()

        # when running in blocking mode we need to make sure collectives can
        # timeout
        with context_timeout(callback, timeout):
            yield

    def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
        # pyre-fixme[21]: no attribute ProcessGroupXCCL
        from torch.distributed import ProcessGroupXCCL as BaseProcessGroupXCCL

        self._errored = None

        # pyre-fixme[16]: no attribute ProcessGroupXCCL
        opts = BaseProcessGroupXCCL.Options()
        # opts.config.blocking = False

        pg = BaseProcessGroup(store, rank, world_size)
        pg._set_default_backend(ProcessGroup.BackendType.XCCL)
        # pyre-fixme[16]: no attribute ProcessGroupXCCL
        backend_class = BaseProcessGroupXCCL(store, rank, world_size, opts)
        backend_class._set_sequence_number_for_group()
        backend_class.eager_connect_single_device(
            torch.device(torch.accelerator.current_device_index())
        )
        pg._register_backend(
            torch.device("xpu"), ProcessGroup.BackendType.XCCL, backend_class
        )
        return pg

    def abort(self, errored: bool = True) -> None:
        # We need to set the error before aborting to ensure that errored()
        # returns the error correctly when XCCL abort fires and unblocks the
        # stream.
        self._errored = RuntimeError("aborted")

        super().abort(errored)

    def errored(self) -> Optional[Exception]:
        # force a synchronization to ensure all work is complete
        torch.xpu.current_stream().synchronize()

        return self._errored

    def getBackendName(self) -> str:
        return "torchft-xccl"


class ProcessGroupDummy(ProcessGroup):
    """
    This process group discards all data passed to it and returns success. This
    is intended for rare cases where we want to discard certain operations
    without modifying the underlying library.

    This PG only supports world_size of 1.
    """

    def __init__(self, rank: int, world: int) -> None:
        super().__init__(rank, world)
        assert rank == 0
        assert world == 1

        self._rank = rank
        self._world = world
        self.wait_count = 0
        self.get_future_count = 0
        self._work: List[Work] = []
        self.configure_count = 0

    def configure(
        self,
        store_addr: str,
        replica_id: str,
        rank: int,
        world_size: int,
        quorum_id: Optional[int] = None,
        group_rank: Optional[int] = None,
        group_world_size: Optional[int] = None,
        global_ranks: Optional[list[int]] = None,
    ) -> None:
        self.configure_count += 1

    def allgather(
        self,
        output_tensors: List[List[torch.Tensor]],
        input_tensor: List[torch.Tensor],
        opts: object,
    ) -> Work:
        for o, i in zip(output_tensors[0], input_tensor):
            o.copy_(i)

        res = _DummyWork(output_tensors)
        self._work.append(res)
        return res

    def allgather_into_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: AllgatherOptions,
    ) -> Work:
        for o, i in zip(output_tensors, input_tensors):
            o.copy_(i)

        res = _DummyWork(output_tensors)
        self._work.append(res)
        return res

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        res = _DummyWork(tensors)
        self._work.append(res)
        return res

    def allreduce_coalesced(
        self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
    ) -> Work:
        res = _DummyWork(tensors)
        self._work.append(res)
        return res

    def alltoall_base(
        self,
        output_buffer: torch.Tensor,
        input_buffer: torch.Tensor,
        output_split_sizes: List[int],
        input_split_sizes: List[int],
        opts: AllToAllOptions,
    ) -> Work:
        output_buffer.copy_(input_buffer)
        res = _DummyWork([output_buffer])
        self._work.append(res)
        return res

    def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
        return _DummyWork(None)

    def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
        res = _DummyWork(tensor_list)
        self._work.append(res)
        return res

    def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
        return _DummyWork(None)

    def reduce_scatter(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[List[torch.Tensor]],
        opts: object,
    ) -> Work:
        for o, i in zip(output_tensors, input_tensors[0]):
            o.copy_(i)

        res = _DummyWork(output_tensors)
        self._work.append(res)
        return res

    def reduce_scatter_tensor_coalesced(
        self,
        output_tensors: List[torch.Tensor],
        input_tensors: List[torch.Tensor],
        opts: ReduceScatterOptions,
    ) -> Work:
        for o, i in zip(output_tensors, input_tensors):
            o.copy_(i)

        res = _DummyWork(output_tensors)
        self._work.append(res)
        return res

    def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
        return _DummyWork(None)

    def size(self) -> int:
        return self._world

    def getBackendName(self) -> str:
        return "torchft-dummy"


class _ErrorSwallowingWork(Work):
    def __init__(
        self,
        pg: "ErrorSwallowingProcessGroupWrapper",
        work: Work,
        default_result: object,
    ) -> None:
        super().__init__()

        self._pg = pg
        self._work = work
        self._default_result = default_result

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        try:
            self._work.wait()
        except Exception as e:
            self._pg.report_error(e)

        return True

    def get_future(self) -> Future[object]:
        fut = self._work.get_future()

        # schedule error handling as a continuation on the Future
        def callback(
            fut: torch.futures.Future[List[torch.Tensor]],
        ) -> object:
            try:
                return fut.value()
            except Exception as e:
                logger.exception(f"got exception in future -- skipping remaining: {e}")
                self._pg.report_error(e)
                return self._default_result

        fut = fut.then(callback)
        return fut


class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
    """
    This is a wrapper around any ProcessGroup that will swallow errors and
    return dummy results on error.

    This is intended to allow handling errors outside of the training loop to
    avoid having to modify modeling code to support error handling.

    After an error occurs all future operations will be skipped until the
    process group is reconfigured via ``configure``.
    """

    def __init__(self, pg: ProcessGroup) -> None:
        super().__init__(pg=pg)

        self._error: Optional[Exception] = None

    def configure(
        self,
        store_addr: str,
        replica_id: str,
        rank: int,
        world_size: int,
        quorum_id: Optional[int] = None,
        group_rank: Optional[int] = None,
        group_world_size: Optional[int] = None,
        global_ranks: Optional[list[int]] = None,
    ) -> None:
        self._error = None

        super().configure(
            store_addr,
            replica_id,
            rank,
            world_size,
            quorum_id,
            group_rank,
            group_world_size,
            global_ranks,
        )

    def report_error(self, e: Exception) -> None:
        """
        Report an error to this process group. This will cause all future
        operations to be skipped until the process group is reconfigured via
        ``configure``.

        Args:
            e: exception to report
        """
        self._error = e

    def error(self) -> Optional[Exception]:
        """
        Returns the error that was reported to this process group.

        Returns:
            exception that was reported
        """
        return self._error

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        if self._error is not None:
            return _DummyWork(tensors)

        try:
            return _ErrorSwallowingWork(
                self,
                super().allreduce(tensors, opts),
                tensors,
            )
        except Exception as e:
            self.report_error(e)
            return _DummyWork(tensors)
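

# Illustrative sketch of the error swallowing wrapper: collectives that fail
# return dummy results instead of raising, and the caller polls error() to
# decide when to reconfigure. The store addresses, ranks and the wrapped Gloo
# group are placeholders.
def _example_error_swallowing_step(
    grad: torch.Tensor, store_addr: str, next_store_addr: str
) -> None:
    wrapper = ErrorSwallowingProcessGroupWrapper(ProcessGroupGloo())
    wrapper.configure(store_addr, replica_id="replica_0", rank=0, world_size=2)
    wrapper.allreduce([grad], AllreduceOptions()).wait()
    if wrapper.error() is not None:
        # an earlier collective failed; results since then were dummies, so
        # reconfigure on a fresh store prefix before relying on collectives
        wrapper.configure(next_store_addr, replica_id="replica_0", rank=0, world_size=2)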


class FakeProcessGroupWrapper(ProcessGroupWrapper):
    """
    This is a wrapper around any ProcessGroup that can be used to inject
    errors into the process group at various points.

    This is intended to be used for tests so that they can test cases
    in which process group operations error out.
    """

    def __init__(self, pg: ProcessGroup) -> None:
        super().__init__(pg=pg)

        self._future_error: Optional[Exception] = None

    def configure(
        self,
        store_addr: str,
        replica_id: str,
        rank: int,
        world_size: int,
        quorum_id: Optional[int] = None,
        group_rank: Optional[int] = None,
        group_world_size: Optional[int] = None,
        global_ranks: Optional[list[int]] = None,
    ) -> None:
        self._future_error = None

        super().configure(
            store_addr,
            replica_id,
            rank,
            world_size,
            quorum_id,
            group_rank,
            group_world_size,
            global_ranks,
        )

    def report_future_error(self, e: Exception) -> None:
        """
        Report an error to this process group. This will cause the
        future attached to the next operation to error out.

        Args:
            e: exception to report
        """
        self._future_error = e

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        work = super().allreduce(tensors, opts)

        if self._future_error is None:
            return work

        fut = work.get_future()

        def callback(
            fut: torch.futures.Future[List[torch.Tensor]],
        ) -> List[torch.Tensor]:
            future_error, self._future_error = self._future_error, None
            assert future_error is not None
            raise future_error

        fut = fut.then(callback)

        return work
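

# Illustrative sketch of the fault injection hook for tests: after
# report_future_error(), the future attached to the next collective errors
# out with the injected exception. The wrapped group is assumed to already
# be configure()d.
def _example_inject_future_error(fake: FakeProcessGroupWrapper) -> None:
    fake.report_future_error(RuntimeError("injected failure"))
    t = torch.ones(2)
    fake.allreduce([t], AllreduceOptions()).wait()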


class ManagedProcessGroup(ProcessGroupWrapper):
    """
    This is a wrapper around any ProcessGroup that is managed by a torchft
    Manager.

    This uses the ProcessGroup that is configured in the Manager. The world size
    is dynamic and will report the number of active participants in the quorum to
    the model.

    Any errors will be asynchronously reported to the manager and only successes
    will be returned to the caller.
    """

    def __init__(self, manager: "Manager") -> None:
        super().__init__(pg=manager._pg)

        self._manager = manager

    def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
        assert len(tensors) == 1

        if isinstance(opts, ReduceOp):
            return self._manager.allreduce(tensors[0], reduce_op=opts)

        if isinstance(opts, AllreduceOptions):
            return self._manager.allreduce(tensors[0], reduce_op=opts.reduceOp)

        assert False, "unreachable"

    def size(self) -> int:
        return self._manager.num_participants()

    def getBackendName(self) -> str:
        return self._manager._pg.getBackendName()
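

# Illustrative sketch of ManagedProcessGroup: gradient allreduces are routed
# through the torchft Manager and size() tracks however many replicas are in
# the current quorum. Assumes an already constructed Manager instance.
def _example_managed_allreduce(manager: "Manager", grad: torch.Tensor) -> None:
    pg = ManagedProcessGroup(manager)
    pg.allreduce([grad], ReduceOp.AVG).wait()
    num_participants = pg.size()  # dynamic; reflects the active quorum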


class _BabyWork(Work):
    def __init__(
        self,
        pg: "ProcessGroupBaby",
        op_id: int,
        stream: Optional[torch.Stream],
    ) -> None:
        super().__init__()

        self._pg = pg
        self._op_id = op_id
        self._stream = stream

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        return self._pg._wait(self._op_id, timeout)

    def synchronize(self) -> None:
        # TODO: No one seems to use this and NCCL wait already only waits the
        # stream and is non-blocking on the CPU side so no real need for a
        # separate call.
        raise NotImplementedError("not implemented")

    def get_future(self) -> Future[object]:
        return self._pg._get_future(self._op_id, self._stream)

    def __del__(self) -> None:
        self._pg._del(self._op_id)


def _is_any_cuda(obj: object) -> bool:
    """
    Returns true if any of the tensors in the object are CUDA tensors.

    Supports lists, tuples, dicts, and tensors.
    """
    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)


def _is_any_xpu(obj: object) -> bool:
    """
    Returns true if any of the tensors in the object are XPU tensors.

    Supports lists, tuples, dicts, and tensors.
    """
    return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_xpu, obj)


@dataclass
class _OpMetadata:
    work: Work
    stream: Optional[torch.Stream]

    @contextmanager
    def set_stream(self) -> Generator[None, None, None]:
        with get_stream_context(self.stream):
            yield


@dataclass
class _FutureMetadata:
    future: Future[object]
    stream: Optional[torch.Stream]

    @contextmanager
    def set_stream(self) -> Generator[None, None, None]:
        with get_stream_context(self.stream):
            yield


def _maybe_share_tensors(
    tensor: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
) -> None:
    """Move a tensor / list of tensors to shared memory if not already in shared memory."""
    if isinstance(tensor, list):
        for t in tensor:
            _maybe_share_tensors(t)
    elif isinstance(tensor, torch.Tensor):
        if not tensor.is_shared():
            tensor.share_memory_()
    else:
        raise TypeError(f"expected tensor or list but got {type(tensor)}")


def _assert_list(tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]]) -> None:
    """Assert that the input is a list of tensors or a nested list of tensors."""
    if not isinstance(tensors, list):
        raise TypeError(f"expected list but got {type(tensors)}")
|
|
1443
|
+
|
|
1444
|
+
|
|
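The helpers above are what make CPU tensors visible across the process boundary. A small standalone sketch (not part of the packaged file) of the behavior they rely on: `share_memory_()` is the same call `_maybe_share_tensors` makes, and the child process here is a stand-in for the subprocess collective writing its result in place:

```python
# Sketch: a CPU tensor moved to shared memory is backed by the same storage in
# a spawned child, so in-place writes made by the child are visible to the
# parent without copying the data back over the pipe.
import torch
import torch.multiprocessing as mp


def _child(t: torch.Tensor) -> None:
    t.add_(1)  # stand-in for a collective writing its result in place


if __name__ == "__main__":
    ctx = mp.get_context("spawn")
    t = torch.zeros(4)
    t.share_memory_()  # same call _maybe_share_tensors makes
    p = ctx.Process(target=_child, args=(t,), daemon=True)
    p.start()
    p.join()
    print(t)  # tensor([1., 1., 1., 1.]) -- the parent sees the child's write
```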
1445 + class ProcessGroupBaby(ProcessGroup):
1446 +     """
1447 +     This is a process group that runs the underlying process group in a
1448 +     subprocess. Since it's running in a subprocess all tensors need to be in
1449 +     shared memory or will be moved to shared memory. CUDA/XPU tensors are implicitly
1450 +     shareable and don't need any changes.
1451 +     """
1452 +
1453 +     def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
1454 +         super().__init__(0, 1)
1455 +
1456 +         self._world_size = -1
1457 +
1458 +         self._p: Optional[mp.Process] = None
1459 +         self._pipe: Optional[_MonitoredPipe] = None
1460 +         self._future_pipe: Optional[_MonitoredPipe] = None
1461 +         self._future_thread: Optional[threading.Thread] = None
1462 +         self._futures: Dict[int, _FutureMetadata] = {}
1463 +         self._futures_lock = threading.Lock()
1464 +
1465 +         self._next_op_id = 0
1466 +
1467 +         if isinstance(timeout, timedelta):
1468 +             timeout = timeout.total_seconds()
1469 +
1470 +         self._timeout: float = timeout
1471 +
1472 +     def shutdown(self) -> None:
1473 +         """
1474 +         Shutdown the process group. This will kill the underlying process and
1475 +         close all queues.
1476 +
1477 +         This is a no-op if the process group is already shutdown.
1478 +
1479 +         ProcessGroup can be reconfigured after shutdown.
1480 +         """
1481 +
1482 +         if self._pipe is not None:
1483 +             self._pipe.close()
1484 +
1485 +         future_pipe = self._future_pipe
1486 +         if future_pipe is not None:
1487 +             # wait for the future thread to exit and then close the queue
1488 +             future_pipe.close()
1489 +
1490 +             future_thread = self._future_thread
1491 +             assert future_thread is not None
1492 +
1493 +             future_thread.join(timeout=10.0)
1494 +             if future_thread.is_alive():
1495 +                 raise RuntimeError("future thread did not exit")
1496 +
1497 +         # Kill after closing queues to avoid log spam.
1498 +         if self._p is not None:
1499 +             self._p.kill()
1500 +
1501 +     def configure(
1502 +         self,
1503 +         store_addr: str,
1504 +         replica_id: str,
1505 +         rank: int,
1506 +         world_size: int,
1507 +         quorum_id: Optional[int] = None,
1508 +         group_rank: Optional[int] = None,
1509 +         group_world_size: Optional[int] = None,
1510 +         global_ranks: Optional[list[int]] = None,
1511 +     ) -> None:
1512 +         self._world_size = world_size
1513 +
1514 +         self.shutdown()
1515 +
1516 +         ctx = mp.get_context("spawn")
1517 +         req_local, req_remote = ctx.Pipe()
1518 +         future_local, future_remote = ctx.Pipe()
1519 +
1520 +         self._pipe = req_local = _MonitoredPipe(req_local)
1521 +         self._future_pipe = future_local = _MonitoredPipe(future_local)
1522 +
1523 +         curr_device = (
1524 +             torch.accelerator.current_device_index()
1525 +             if torch.accelerator.is_available()
1526 +             else -1
1527 +         )
1528 +
1529 +         self._p = p = ctx.Process(
1530 +             target=self._worker,
1531 +             args=(
1532 +                 store_addr,
1533 +                 rank,
1534 +                 world_size,
1535 +                 req_remote,
1536 +                 future_remote,
1537 +                 curr_device,
1538 +             ),
1539 +             daemon=True,
1540 +         )
1541 +         p.start()
1542 +
1543 +         # futures need thread to fire callbacks
1544 +         # this lock needs to be held when manipulating _futures
1545 +         self._futures_lock = threading.Lock()
1546 +         self._futures = {}
1547 +         self._future_thread = threading.Thread(
1548 +             target=self._future_handler,
1549 +             args=(future_local,),
1550 +             daemon=True,
1551 +         )
1552 +         self._future_thread.start()
1553 +
1554 +         # fetch the status of the PG init
1555 +         # if an exception was returned get will throw
1556 +         assert req_local.recv(self._timeout) is None
1557 +
1558 +     @classmethod
1559 +     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
1560 +         """
1561 +         This is a class method to avoid pickling the class.
1562 +         """
1563 +         raise NotImplementedError("not implemented")
1564 +
1565 +     @classmethod
1566 +     def _worker(
1567 +         cls,
1568 +         store_addr: str,
1569 +         rank: int,
1570 +         world_size: int,
1571 +         req_pipe: "Connection[object, object]",  # type: ignore
1572 +         future_pipe: "Connection[object, object]",  # type: ignore
1573 +         curr_device: int,
1574 +     ) -> None:
1575 +         try:
1576 +             if curr_device >= 0 and torch.accelerator.is_available():
1577 +                 torch.accelerator.set_device_index(curr_device)
1578 +
1579 +             store = create_store_client(
1580 +                 store_addr,
1581 +                 # default TCPStore timeout is 5 minutes
1582 +                 timeout=timedelta(minutes=5),
1583 +             )
1584 +
1585 +             try:
1586 +                 pg = cls._create_pg(store, rank, world_size)
1587 +             except Exception as e:
1588 +                 logger.exception(f"got exception in worker: {e}")
1589 +                 req_pipe.send(e)
1590 +                 return
1591 +             req_pipe.send(None)
1592 +
1593 +             streams: Dict[str, torch.Stream] = {}
1594 +             work: Dict[int, _OpMetadata] = {}
1595 +
1596 +             while True:
1597 +                 op = cast(list[object], req_pipe.recv())
1598 +                 cmd = op[0]
1599 +                 if cmd == "func":
1600 +                     op_id: int
1601 +                     op_id, func_name, args, kwargs, stream_device, stream_id, event = (
1602 +                         cast(
1603 +                             Tuple[
1604 +                                 int,
1605 +                                 str,
1606 +                                 list[object],
1607 +                                 dict[str, object],
1608 +                                 int,
1609 +                                 int,
1610 +                                 Optional[Union[torch.cuda.Event, torch.xpu.Event]],
1611 +                             ],
1612 +                             op[1:],
1613 +                         )
1614 +                     )
1615 +
1616 +                     # To avoid potential deadlocks we need to preserve the
1617 +                     # stream/synchronization behavior of the parent process.
1618 +                     # We allocate one Stream per stream_id to make sure that we
1619 +                     # don't accidentally introduce cross stream synchronization
1620 +                     # points.
1621 +                     if stream_id is not None:
1622 +                         stream_key = f"{stream_device}/{stream_id}"
1623 +                         if stream_key not in streams:
1624 +                             streams[stream_key] = torch.Stream(device=stream_device)
1625 +                         stream = streams[stream_key]
1626 +                     else:
1627 +                         stream = None
1628 +
1629 +                     with get_stream_context(stream):
1630 +                         # Make the stream wait on the cuda event to make sure we
1631 +                         # don't start the operation until the tensor is ready.
1632 +                         if event is not None:
1633 +                             event.wait()
1634 +
1635 +                         args = _PickleSafeOptions.unsafe_args(args)
1636 +                         fn = getattr(pg, func_name)
1637 +
1638 +                         work[op_id] = _OpMetadata(
1639 +                             work=fn(*args, **kwargs),
1640 +                             stream=stream,
1641 +                         )
1642 +
1643 +                 elif cmd == "wait":
1644 +                     op_id, timeout = cast(tuple[int, timedelta], op[1:])
1645 +
1646 +                     metadata = work[op_id]
1647 +
1648 +                     with metadata.set_stream():
1649 +                         # With WorkNCCL this makes the stream wait not the CPU when
1650 +                         # no timeout is passed.
1651 +                         if timeout is not None:
1652 +                             metadata.work.wait(timeout)
1653 +                         else:
1654 +                             metadata.work.wait()
1655 +
1656 +                         # Register event on the stream that we can pass to the main
1657 +                         # process.
1658 +                         event = record_event() if metadata.stream is not None else None
1659 +
1660 +                     req_pipe.send((op_id, event))
1661 +                 elif cmd == "del":
1662 +                     op_id: int = cast(int, op[1])
1663 +                     del work[op_id]
1664 +                 elif cmd == "future":
1665 +                     op_id: int = cast(int, op[1])
1666 +                     metadata: _OpMetadata = work[op_id]
1667 +
1668 +                     def callback(fut: Future[object], metadata: _OpMetadata) -> None:
1669 +                         try:
1670 +                             # create an event after the collective has been issued
1671 +                             # to wait on this before we call "future"
1672 +                             with metadata.set_stream():
1673 +                                 fut.wait()
1674 +                                 event = (
1675 +                                     record_event()
1676 +                                     if metadata.stream is not None
1677 +                                     else None
1678 +                                 )
1679 +
1680 +                             future_pipe.send((op_id, _FUTURE_RESULT, None, event))
1681 +                         except Exception as e:
1682 +                             future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))
1683 +
1684 +                     metadata.work.get_future().add_done_callback(
1685 +                         lambda fut: callback(fut, metadata)
1686 +                     )
1687 +                 elif cmd == "num_active_work":
1688 +                     req_pipe.send(len(work))
1689 +                 else:
1690 +                     raise ValueError(f"unknown cmd: {cmd}")
1691 +
1692 +         except Exception as e:
1693 +             logger.exception(f"worker errored: {e}")
1694 +             req_pipe.send(e)
1695 +             raise
1696 +
1697 +     def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
1698 +         try:
1699 +             while True:
1700 +                 try:
1701 +                     cmd = future_pipe.recv(timedelta(seconds=10))
1702 +                 except TimeoutError:
1703 +                     continue
1704 +                 except OSError:
1705 +                     # subprocess exited
1706 +                     break
1707 +
1708 +                 op_id, mode, data, event = cast(
1709 +                     Tuple[
1710 +                         int,
1711 +                         str,
1712 +                         object,
1713 +                         Optional[Union[torch.cuda.Event, torch.xpu.Event]],
1714 +                     ],
1715 +                     cmd,
1716 +                 )
1717 +                 with self._futures_lock:
1718 +                     meta = self._futures[op_id]
1719 +                     del self._futures[op_id]
1720 +                 with meta.set_stream():
1721 +                     if mode == _FUTURE_RESULT:
1722 +                         if event is not None:
1723 +                             event.wait()
1724 +                         meta.future.set_result(data)
1725 +                     elif mode == _FUTURE_EXCEPTION:
1726 +                         meta.future.set_exception(data)
1727 +                     else:
1728 +                         raise ValueError(f"unknown mode {mode}")
1729 +         except Exception as e:
1730 +             logger.exception(f"got unexpected error in future handler: {e}")
1731 +
1732 +     def _get_future(self, op_id: int, stream: Optional[torch.Stream]) -> Future[object]:
1733 +         with self._futures_lock:
1734 +             fut = Future()
1735 +             self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
1736 +         assert self._pipe is not None
1737 +         self._pipe.send(("future", op_id))
1738 +
1739 +         # TODO: return correct tensor instead of None
1740 +         return fut
1741 +
1742 +     def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool:
1743 +         assert self._pipe is not None
1744 +         self._pipe.send(("wait", op_id, timeout))
1745 +
1746 +         assert self._pipe is not None
1747 +         op_id, event = cast(
1748 +             Tuple[int, Optional[Union[torch.cuda.Event, torch.xpu.Event]]],
1749 +             self._pipe.recv(timeout or self._timeout),
1750 +         )
1751 +         assert op_id == op_id
1752 +         if event is not None:
1753 +             event.wait()
1754 +
1755 +         return True
1756 +
1757 +     def _del(self, op_id: int) -> None:
1758 +         assert self._pipe is not None
1759 +         try:
1760 +             self._pipe.send(("del", op_id))
1761 +         except OSError:
1762 +             # if pipe is closed we can safely do nothing
1763 +             pass
1764 +
1765 +     def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
1766 +         pipe = self._pipe
1767 +         assert pipe is not None
1768 +
1769 +         is_accelerator = _is_any_cuda(args) or _is_any_xpu(args)
1770 +
1771 +         stream_device = (
1772 +             torch.accelerator.current_stream().device if is_accelerator else None
1773 +         )
1774 +         stream_id = (
1775 +             torch.accelerator.current_stream().stream_id if is_accelerator else None
1776 +         )
1777 +         event = record_event() if is_accelerator else None
1778 +
1779 +         op_id = self._next_op_id
1780 +         self._next_op_id += 1
1781 +
1782 +         pipe.send(
1783 +             (
1784 +                 "func",
1785 +                 op_id,
1786 +                 func,
1787 +                 _PickleSafeOptions.safe_args(args),
1788 +                 kwargs,
1789 +                 stream_device,
1790 +                 stream_id,
1791 +                 event,
1792 +             ),
1793 +         )
1794 +
1795 +         return _BabyWork(
1796 +             pg=self,
1797 +             op_id=op_id,
1798 +             stream=torch.accelerator.current_stream() if is_accelerator else None,
1799 +         )
1800 +
1801 +     def allgather(
1802 +         self,
1803 +         output_tensors: List[List[torch.Tensor]],
1804 +         input_tensor: List[torch.Tensor],
1805 +         opts: AllgatherOptions,
1806 +     ) -> Work:
1807 +         _assert_list(output_tensors)
1808 +         _assert_list(input_tensor)
1809 +         _maybe_share_tensors(output_tensors)
1810 +         _maybe_share_tensors(input_tensor)
1811 +         return self._run_func("allgather", output_tensors, input_tensor, opts)
1812 +
1813 +     def allgather_into_tensor_coalesced(
1814 +         self,
1815 +         output_tensors: List[torch.Tensor],
1816 +         input_tensors: List[torch.Tensor],
1817 +         opts: AllgatherOptions,
1818 +     ) -> Work:
1819 +         _assert_list(output_tensors)
1820 +         _assert_list(input_tensors)
1821 +         _maybe_share_tensors(output_tensors)
1822 +         _maybe_share_tensors(input_tensors)
1823 +         return self._run_func(
1824 +             "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
1825 +         )
1826 +
1827 +     def allreduce(
1828 +         self,
1829 +         tensors: List[torch.Tensor],
1830 +         opts: Union[dist.AllreduceOptions, dist.ReduceOp],
1831 +     ) -> Work:
1832 +         _assert_list(tensors)
1833 +         _maybe_share_tensors(tensors)
1834 +         return self._run_func("allreduce", tensors, opts)
1835 +
1836 +     def allreduce_coalesced(
1837 +         self,
1838 +         tensors: List[torch.Tensor],
1839 +         opts: Union[dist.AllreduceCoalescedOptions, dist.ReduceOp],
1840 +     ) -> Work:
1841 +         _assert_list(tensors)
1842 +         _maybe_share_tensors(tensors)
1843 +         return self._run_func("allreduce_coalesced", tensors, opts)
1844 +
1845 +     def alltoall_base(
1846 +         self,
1847 +         output_buffer: torch.Tensor,
1848 +         input_buffer: torch.Tensor,
1849 +         output_split_sizes: List[int],
1850 +         input_split_sizes: List[int],
1851 +         opts: AllToAllOptions,
1852 +     ) -> Work:
1853 +         _maybe_share_tensors(output_buffer)
1854 +         _maybe_share_tensors(input_buffer)
1855 +         return self._run_func(
1856 +             "alltoall_base",
1857 +             output_buffer,
1858 +             input_buffer,
1859 +             output_split_sizes,
1860 +             input_split_sizes,
1861 +             opts,
1862 +         )
1863 +
1864 +     def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
1865 +         return self._run_func("barrier", opts)
1866 +
1867 +     def broadcast(
1868 +         self,
1869 +         tensor_list: List[torch.Tensor],
1870 +         opts: BroadcastOptions,
1871 +     ) -> Work:
1872 +         _assert_list(tensor_list)
1873 +         _maybe_share_tensors(tensor_list)
1874 +         return self._run_func("broadcast", tensor_list, opts)
1875 +
1876 +     def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
1877 +         _assert_list(tensors)
1878 +         _maybe_share_tensors(tensors)
1879 +         return self._run_func("recv", tensors, src_rank, tag)
1880 +
1881 +     def reduce_scatter(
1882 +         self,
1883 +         output_tensors: List[torch.Tensor],
1884 +         input_tensors: List[List[torch.Tensor]],
1885 +         opts: ReduceScatterOptions,
1886 +     ) -> Work:
1887 +         _assert_list(output_tensors)
1888 +         _assert_list(input_tensors)
1889 +         _maybe_share_tensors(output_tensors)
1890 +         _maybe_share_tensors(input_tensors)
1891 +         return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
1892 +
1893 +     def reduce_scatter_tensor_coalesced(
1894 +         self,
1895 +         output_tensors: List[torch.Tensor],
1896 +         input_tensors: List[torch.Tensor],
1897 +         opts: ReduceScatterOptions,
1898 +     ) -> Work:
1899 +         _assert_list(output_tensors)
1900 +         _assert_list(input_tensors)
1901 +         _maybe_share_tensors(output_tensors)
1902 +         _maybe_share_tensors(input_tensors)
1903 +         return self._run_func(
1904 +             "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
1905 +         )
1906 +
1907 +     def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
1908 +         _assert_list(tensors)
1909 +         _maybe_share_tensors(tensors)
1910 +         return self._run_func("send", tensors, dst_rank, tag)
1911 +
1912 +     def size(self) -> int:
1913 +         return self._world_size
1914 +
1915 +     def num_active_work(self) -> int:
1916 +         assert self._pipe is not None
1917 +         self._pipe.send(("num_active_work",))
1918 +
1919 +         assert self._pipe is not None
1920 +         return cast(int, self._pipe.recv(self._timeout))
1921 +
1922 +     def set_timeout(self, timeout: timedelta) -> None:
1923 +         self._timeout = timeout.total_seconds()
1924 +
1925 +
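For reference, the parent/child wire protocol used by the class above is summarized below. This is reconstructed from `_run_func`, `_wait`, `_del`, `_get_future`, `num_active_work` and the `_worker` loop; it is an internal detail of ProcessGroupBaby, not a public API:

```python
# Request pipe, parent -> child:
#   ("func", op_id, func_name, safe_args, kwargs, stream_device, stream_id, event)
#       run getattr(pg, func_name)(*args, **kwargs) on the child's stream
#   ("wait", op_id, timeout)    # child replies on the same pipe with (op_id, event_or_None)
#   ("del", op_id)              # drop the child-side Work entry
#   ("future", op_id)           # completion is reported on the future pipe
#   ("num_active_work",)        # child replies with len(work)
#
# Future pipe, child -> parent:
#   (op_id, _FUTURE_RESULT, None, event_or_None)   # success; parent waits on event, sets result
#   (op_id, _FUTURE_EXCEPTION, exception, None)    # failure; parent sets the exception
```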
1926 + @dataclass
1927 + class _PickleSafeOptions:
1928 +     func: Callable[[], object]
1929 +     fields: Dict[str, object]
1930 +
1931 +     @classmethod
1932 +     def safe_args(cls, args: T) -> T:
1933 +         if isinstance(args, tuple):
1934 +             return tuple(cls.safe_args(arg) for arg in args)
1935 +         elif isinstance(args, list):
1936 +             return [cls.safe_args(arg) for arg in args]
1937 +         elif isinstance(
1938 +             args,
1939 +             (
1940 +                 AllgatherOptions,
1941 +                 AllreduceOptions,
1942 +                 AllreduceCoalescedOptions,
1943 +                 AllToAllOptions,
1944 +                 BarrierOptions,
1945 +                 BroadcastOptions,
1946 +                 ReduceScatterOptions,
1947 +             ),
1948 +         ):
1949 +             return cls.from_torch(args)
1950 +         else:
1951 +             return args
1952 +
1953 +     @classmethod
1954 +     def unsafe_args(cls, args: T) -> T:
1955 +         if isinstance(args, tuple):
1956 +             return tuple(cls.unsafe_args(arg) for arg in args)
1957 +         elif isinstance(args, list):
1958 +             return [cls.unsafe_args(arg) for arg in args]
1959 +         elif isinstance(args, cls):
1960 +             return args.to_torch()
1961 +         else:
1962 +             return args
1963 +
1964 +     @classmethod
1965 +     def from_torch(cls, opts: object) -> "_PickleSafeOptions":
1966 +         return cls(
1967 +             func=opts.__class__,
1968 +             fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
1969 +         )
1970 +
1971 +     def to_torch(self) -> object:
1972 +         opts = self.func()
1973 +         for k, v in self.fields.items():
1974 +             setattr(opts, k, v)
1975 +         return opts
1976 +
1977 +
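A small sketch (not part of the packaged file) of the round trip `_PickleSafeOptions` performs, assuming the private helper is imported directly from `torchft.process_group`. The point is that the C++-backed `*Options` objects are captured as a plain class-plus-fields dataclass and rebuilt field by field on the child side instead of being pickled directly:

```python
import torch.distributed as dist

from torchft.process_group import _PickleSafeOptions  # private helper, sketch only

opts = dist.AllreduceOptions()
opts.reduceOp = dist.ReduceOp.SUM

safe = _PickleSafeOptions.from_torch(opts)  # plain dataclass: class + public fields
restored = safe.to_torch()                  # fresh AllreduceOptions, fields copied back

assert isinstance(restored, dist.AllreduceOptions)
print(restored.reduceOp)                    # ReduceOp.SUM
```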
1978 + class ProcessGroupBabyGloo(ProcessGroupBaby):
1979 +     """
1980 +     This is a ProcessGroup that runs Gloo in a subprocess.
1981 +
1982 +     For most use cases you should prefer ProcessGroupGloo or
1983 +     ProcessGroupBabyNCCL.
1984 +     """
1985 +
1986 +     @classmethod
1987 +     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
1988 +         pg = BaseProcessGroup(store, rank, world_size)
1989 +         pg._set_default_backend(ProcessGroup.BackendType.GLOO)
1990 +         # pyre-fixme[16]: no attribute ProcessGroupGloo
1991 +         backend_class = BaseProcessGroupGloo(store, rank, world_size)
1992 +         pg._register_backend(
1993 +             torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
1994 +         )
1995 +         return pg
1996 +
1997 +     def getBackendName(self) -> str:
1998 +         return "torchft-baby-gloo"
1999 +
2000 +     # pyre-fixme[15]: inconsistent override
2001 +     def reduce_scatter(
2002 +         self,
2003 +         output_tensors: List[torch.Tensor],
2004 +         input_tensors: List[List[torch.Tensor]],
2005 +         opts: ReduceScatterOptions,
2006 +     ) -> None:
2007 +         """
2008 +         This function is a placeholder for the reduce_scatter operation in the
2009 +         ProcessGroupGloo class. However, this operation is not supported by the
2010 +         Gloo backend, and thus, calling this function will raise a
2011 +         RuntimeError.
2012 +
2013 +         Raises:
2014 +             RuntimeError: Always raised since reduce_scatter is not
2015 +                 supported by ProcessGroupGloo.
2016 +         """
2017 +         raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
2018 +
2019 +     # pyre-fixme[15]: inconsistent override
2020 +     def reduce_scatter_tensor_coalesced(
2021 +         self,
2022 +         output_tensors: List[torch.Tensor],
2023 +         input_tensors: List[torch.Tensor],
2024 +         opts: ReduceScatterOptions,
2025 +     ) -> None:
2026 +         """
2027 +         This function is a placeholder for the reduce_scatter_tensor_coalesced
2028 +         operation in the ProcessGroupBabyGloo class.
2029 +         However, this operation is not supported by the
2030 +         Gloo backend, and thus, calling this function will raise a
2031 +         RuntimeError.
2032 +
2033 +         Raises:
2034 +             RuntimeError: Always raised since reduce_scatter is not
2035 +                 supported by ProcessGroupBabyGloo.
2036 +         """
2037 +         raise RuntimeError(
2038 +             "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
2039 +         )
2040 +
2041 +
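A hedged end-to-end sketch of driving ProcessGroupBabyGloo (not part of the packaged file). The `host:port/prefix` form of `store_addr` mirrors the pattern used in torchft's own tests and should be treated as an assumption rather than documented API; a single-rank group keeps the example self-contained:

```python
import torch
import torch.distributed as dist

from torchft.process_group import ProcessGroupBabyGloo

if __name__ == "__main__":
    # TCPStore on an ephemeral port; the subprocess connects back to it via
    # create_store_client using the address string below (assumed format).
    store = dist.TCPStore("localhost", 0, is_master=True, wait_for_workers=False)
    store_addr = f"localhost:{store.port}/prefix"

    pg = ProcessGroupBabyGloo(timeout=60.0)
    pg.configure(store_addr, replica_id="replica0", rank=0, world_size=1)

    t = torch.ones(4)
    work = pg.allreduce([t], dist.ReduceOp.SUM)  # tensor is moved to shared memory
    work.wait()
    print(t)

    pg.shutdown()
```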
2042 + class ProcessGroupBabyNCCL(ProcessGroupBaby):
2043 +     """
2044 +     This is a ProcessGroup that runs NCCL in a subprocess.
2045 +
2046 +     For the NCCL backend, extra memory will be used by the subprocesses CUDA
2047 +     context compared to running NCCL in the main process. This is typically
2048 +     around ~1GB.
2049 +
2050 +     The returned Work objects only synchronize on the cuda stream and not on the
2051 +     CPU side. This works by passing CUDA Events between the processes. To do a
2052 +     CPU synchronize, call torch.cuda.synchronize() after wait().
2053 +
2054 +     WARNING: If the child process is killed while an operation is running, CUDA
2055 +     tensors may leak in the current PyTorch implementation. TODO fix
2056 +
2057 +     WARNING: As this uses a separate CUDA context for the subprocess, performance
2058 +     may be slower than using NCCL directly. Separate CUDA contexts can not run
2059 +     at the same time so network and compute kernels will not overlap execution
2060 +     and instead do time sharing which may reduce GPU utilization.
2061 +     """
2062 +
2063 +     @classmethod
2064 +     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
2065 +         from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
2066 +
2067 +         pg = BaseProcessGroup(store, rank, world_size)
2068 +         pg._set_default_backend(ProcessGroup.BackendType.NCCL)
2069 +         # pyre-fixme[16]: no attribute ProcessGroupNCCL
2070 +         backend_class = BaseProcessGroupNCCL(store, rank, world_size)
2071 +         backend_class._set_sequence_number_for_group()
2072 +         pg._register_backend(
2073 +             torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
2074 +         )
2075 +         return pg
2076 +
2077 +     def getBackendName(self) -> str:
2078 +         return "torchft-baby-nccl"
2079 +
2080 +
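As the docstring above notes, `wait()` on the returned Work only inserts a dependency on the current CUDA stream. A short sketch of the pattern it prescribes (not part of the packaged file), assuming `pg` is an already-configured ProcessGroupBabyNCCL and CUDA is available:

```python
import torch
import torch.distributed as dist

t = torch.ones(1024, device="cuda")
work = pg.allreduce([t], dist.ReduceOp.SUM)

work.wait()               # stream-level dependency only; the CPU is not blocked
torch.cuda.synchronize()  # explicit CPU synchronize before host-side reads
print(t[:4])
```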
2081 + class ProcessGroupBabyXCCL(ProcessGroupBaby):
2082 +     """
2083 +     This is a ProcessGroup that runs XCCL in a subprocess for Intel XPU devices.
2084 +
2085 +     For the XCCL backend, extra memory will be used by the subprocesses XPU
2086 +     context compared to running XCCL in the main process. This is typically
2087 +     dependent on the XPU memory architecture.
2088 +
2089 +     The returned Work objects only synchronize on the XPU stream and not on the
2090 +     CPU side. This works by passing XPU Events between the processes. To do a
2091 +     CPU synchronize, call torch.xpu.synchronize() after wait().
2092 +
2093 +     WARNING: If the child process is killed while an operation is running, XPU
2094 +     tensors may leak in the current PyTorch implementation. TODO fix
2095 +
2096 +     WARNING: As this uses a separate XPU context for the subprocess, performance
2097 +     may be slower than using XCCL directly. Separate XPU contexts can not run
2098 +     at the same time so network and compute kernels will not overlap execution
2099 +     and instead do time sharing which may reduce XPU utilization.
2100 +     """
2101 +
2102 +     @classmethod
2103 +     def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
2104 +         # Check if XPU and XCCL are available
2105 +         from torch.distributed import ProcessGroupXCCL as BaseProcessGroupXCCL
2106 +
2107 +         pg = BaseProcessGroup(store, rank, world_size)
2108 +         pg._set_default_backend(ProcessGroup.BackendType.XCCL)
2109 +         # pyre-fixme[16]: no attribute ProcessGroupNCCL
2110 +         backend_class = BaseProcessGroupXCCL(store, rank, world_size)
2111 +         backend_class._set_sequence_number_for_group()
2112 +         pg._register_backend(
2113 +             torch.device("xpu"), ProcessGroup.BackendType.XCCL, backend_class
2114 +         )
2115 +         return pg
2116 +
2117 +     def getBackendName(self) -> str:
2118 +         return "torchft-baby-xccl"