torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/process_group.py
@@ -0,0 +1,2118 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Process Groups
9
+ =========================
10
+
11
+ This module implements fault tolerant process groups that can be reconfigured
12
+ and resized at runtime.
13
+
14
+ These extend the standard PyTorch ProcessGroup API and can be used in most
15
+ places that would accept a standard process group. As these can change size at
16
+ runtime, users need to take care not to assume a static rank or world size.
17
+ """
18
+
19
+ import logging
20
+ import os
21
+ import threading
22
+ import time
23
+ import warnings
24
+ from contextlib import contextmanager, nullcontext
25
+ from dataclasses import dataclass
26
+ from datetime import timedelta
27
+ from multiprocessing.connection import Connection
28
+ from typing import (
29
+ Any,
30
+ Callable,
31
+ cast,
32
+ Dict,
33
+ Generator,
34
+ List,
35
+ Optional,
36
+ Tuple,
37
+ TYPE_CHECKING,
38
+ TypeVar,
39
+ Union,
40
+ )
41
+
42
+ import torch
43
+ import torch.distributed as dist
44
+ import torch.multiprocessing as mp
45
+
46
+ # pyre-fixme[21]: no attribute ProcessGroupGloo
47
+ from torch.distributed import (
48
+ PrefixStore,
49
+ ProcessGroup as BaseProcessGroup,
50
+ ProcessGroupGloo as BaseProcessGroupGloo,
51
+ Store,
52
+ TCPStore,
53
+ )
54
+ from torch.distributed.distributed_c10d import (
55
+ AllgatherOptions,
56
+ AllreduceCoalescedOptions,
57
+ AllreduceOptions,
58
+ AllToAllOptions,
59
+ BarrierOptions,
60
+ BroadcastOptions,
61
+ ReduceOp,
62
+ ReduceScatterOptions,
63
+ Work,
64
+ )
65
+ from torch.futures import Future
66
+ from torch.utils._pytree import tree_any
67
+
68
+ # We import these for backwards compatibility
69
+ from torchft.futures import context_timeout, stream_timeout
70
+ from torchft.multiprocessing import _MonitoredPipe
71
+ from torchft.utils import get_stream_context, record_event, synchronize
72
+ from torchft.work import _DummyWork
73
+
74
+ if TYPE_CHECKING:
75
+ from torchft.manager import Manager
76
+
77
+ logger: logging.Logger = logging.getLogger(__name__)
78
+
79
+ # TODO: use non strings which are cheaper
80
+ _QUEUE_CLOSE = "queue_close"
81
+ _FUTURE_RESULT = "fut_result"
82
+ _FUTURE_EXCEPTION = "fut_exception"
83
+
84
+
85
+ T = TypeVar("T")
86
+
87
+ TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR = "TORCH_NCCL_DEBUG_INFO_PIPE_FILE"
88
+ # Used to trigger flight recorder if we trigger abort on the process group
89
+ TORCHFT_TRIGGER_FR_ON_ABORT = "TORCHFT_TRIGGER_FR_ON_ABORT"
90
+
91
+
92
+ def trigger_nccl_fr_trace_through_pipe(rank: int) -> bool:
93
+ """Collect NCCL flight recorder trace through the pipe."""
94
+ dump_file_prefix = os.environ.get(TORCH_NCCL_DEBUG_INFO_PIPE_FILE_ENV_VAR, "")
95
+ if not dump_file_prefix:
96
+ logging.info(
97
+ f"[rank {rank}] Triggering FR trace dump through pipe failed: pipe is not enabled."
98
+ )
99
+ return False
100
+ pipe_name = f"{dump_file_prefix}{rank}.pipe"
101
+ with open(pipe_name, "w") as f:
102
+ # Trigger fr trace dump through pipe
103
+ logging.info(f"[rank {rank}] Triggering FR trace dump through pipe...")
104
+ f.write("1\n")
105
+ time.sleep(60)
106
+ return True
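Both environment variables used above come from this module's constants; a sketch of how a launcher might enable the dump path before workers start (the /tmp prefix is an arbitrary example, and it assumes the PyTorch NCCL dump pipe is enabled in this build; the helper writes to "{prefix}{rank}.pipe"):

    import os

    # Trigger a flight recorder dump whenever the process group is aborted ...
    os.environ["TORCHFT_TRIGGER_FR_ON_ABORT"] = "true"
    # ... using a named-pipe prefix that this helper and torch's NCCL dump pipe agree on.
    os.environ["TORCH_NCCL_DEBUG_INFO_PIPE_FILE"] = "/tmp/nccl_fr_"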
107
+
108
+
109
+ def create_store_client(store_addr: str, timeout: timedelta) -> Store:
110
+ """
111
+ Creates a PrefixStore(TCPStore(...)) client from an address in the format:
112
+
113
+ host:port/prefix
114
+
115
+ Ex: localhost:1234/my/prefix
116
+ """
117
+ host, _, rest = store_addr.partition(":")
118
+ port, _, prefix = rest.partition("/")
119
+
120
+ store = TCPStore(
121
+ host_name=host,
122
+ port=int(port),
123
+ is_master=False,
124
+ wait_for_workers=False,
125
+ timeout=timeout,
126
+ )
127
+ store = PrefixStore(prefix, store)
128
+ return store
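For example (a sketch; it assumes a TCPStore server is already running on localhost:1234, since the client is created with is_master=False):

    from datetime import timedelta

    store = create_store_client("localhost:1234/my/prefix", timeout=timedelta(seconds=30))
    store.set("status", "ready")            # key is stored under the "my/prefix" prefix
    assert store.get("status") == b"ready"  # TCPStore returns bytes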
129
+
130
+
131
+ class ProcessGroup(BaseProcessGroup):
132
+ def __init__(self, *args: object, **kwargs: object) -> None:
133
+ # pyre-fixme[6]: got object
134
+ super().__init__(*args, **kwargs)
135
+
136
+ self._group_name: Optional[str] = None
137
+
138
+ # pyre-fixme[14]: inconsistent override
139
+ def allgather(
140
+ self,
141
+ output_tensors: List[List[torch.Tensor]],
142
+ input_tensor: List[torch.Tensor],
143
+ opts: AllgatherOptions,
144
+ ) -> Work:
145
+ """
146
+ Gathers tensors from the whole group in a list.
147
+
148
+ See torch.distributed.all_gather for more details.
149
+ """
150
+ raise NotImplementedError("not implemented")
151
+
152
+ # pyre-fixme[14]: inconsistent override
153
+ def allgather_into_tensor_coalesced(
154
+ self,
155
+ output_tensors: List[torch.Tensor],
156
+ input_tensors: List[torch.Tensor],
157
+ opts: AllgatherOptions,
158
+ ) -> Work:
159
+ """
160
+ Performs an allgather operation on coalesced tensors.
161
+
162
+ See torch.distributed.allgather_coalesced for more details.
163
+ """
164
+ raise NotImplementedError("not implemented")
165
+
166
+ # pyre-fixme[14]: inconsistent override
167
+ def allreduce(
168
+ self,
169
+ tensors: List[torch.Tensor],
170
+ opts: Union[AllreduceOptions, ReduceOp],
171
+ ) -> Work:
172
+ """
173
+ Reduces the tensor data across all machines in such a way that all get the final result.
174
+
175
+ See torch.distributed.all_reduce for more details.
176
+ """
177
+ raise NotImplementedError("not implemented")
178
+
179
+ def allreduce_coalesced(
180
+ self,
181
+ tensors: List[torch.Tensor],
182
+ opts: AllreduceCoalescedOptions,
183
+ ) -> Work:
184
+ """
185
+ Performs an all_reduce operation in a coalesced manner.
186
+
187
+ See torch.distributed.all_reduce_coalesced for more details.
188
+ """
189
+ raise NotImplementedError("not implemented")
190
+
191
+ # pyre-fixme[14]: inconsistent override
192
+ def alltoall_base(
193
+ self,
194
+ output_buffer: torch.Tensor,
195
+ input_buffer: torch.Tensor,
196
+ output_split_sizes: List[int],
197
+ input_split_sizes: List[int],
198
+ opts: AllToAllOptions,
199
+ ) -> Work:
200
+ """
201
+ Performs an all_to_all operation.
202
+
203
+ See torch.distributed.all_to_all_single for more details.
204
+ """
205
+ raise NotImplementedError("not implemented")
206
+
207
+ # pyre-fixme[14]: inconsistent override
208
+ def barrier(self, opts: BarrierOptions) -> Work:
209
+ """
210
+ Synchronizes all processes.
211
+
212
+ See torch.distributed.barrier for more details.
213
+ """
214
+ raise NotImplementedError("not implemented")
215
+
216
+ # pyre-fixme[14]: inconsistent override
217
+ def broadcast(
218
+ self, tensor_list: List[torch.Tensor], opts: BroadcastOptions
219
+ ) -> Work:
220
+ """
221
+ Broadcasts the tensor to the whole group.
222
+
223
+ See torch.distributed.broadcast for more details.
224
+ """
225
+ raise NotImplementedError("not implemented")
226
+
227
+ def broadcast_one(self, tensor: torch.Tensor, root: int) -> Work:
228
+ opts = BroadcastOptions()
229
+ opts.rootRank = root
230
+ return self.broadcast([tensor], opts)
231
+
232
+ # pyre-fixme[14]: inconsistent override
233
+ def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
234
+ """
235
+ Receives a list of tensors from the process with rank `src_rank`.
236
+
237
+ See torch.distributed.recv for more details.
238
+ """
239
+ raise NotImplementedError("not implemented")
240
+
241
+ # pyre-fixme[14]: inconsistent override
242
+ def reduce_scatter(
243
+ self,
244
+ output_tensors: List[torch.Tensor],
245
+ input_tensors: List[List[torch.Tensor]],
246
+ opts: ReduceScatterOptions,
247
+ ) -> Work:
248
+ """
249
+ Reduces, then scatters a list of tensors to all processes in a group.
250
+
251
+ See torch.distributed.reduce_scatter for more details.
252
+ """
253
+ raise NotImplementedError("not implemented")
254
+
255
+ # pyre-fixme[14]: inconsistent override
256
+ def reduce_scatter_tensor_coalesced(
257
+ self,
258
+ output_tensors: List[torch.Tensor],
259
+ input_tensors: List[torch.Tensor],
260
+ opts: ReduceScatterOptions,
261
+ ) -> Work:
262
+ """
263
+ Performs a reduce-scatter operation on coalesced tensors.
264
+
265
+ See torch.distributed.reduce_scatter_tensor for more details.
266
+ """
267
+ raise NotImplementedError("not implemented")
268
+
269
+ # pyre-fixme[14]: inconsistent override
270
+ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
271
+ """
272
+ Sends a list of tensors to the process with rank `dst_rank`.
273
+
274
+ See torch.distributed.send for more details.
275
+ """
276
+ raise NotImplementedError("not implemented")
277
+
278
+ def configure(
279
+ self,
280
+ store_addr: str,
281
+ replica_id: str,
282
+ rank: int,
283
+ world_size: int,
284
+ quorum_id: Optional[int] = None,
285
+ group_rank: Optional[int] = None,
286
+ group_world_size: Optional[int] = None,
287
+ global_ranks: Optional[list[int]] = None,
288
+ ) -> None:
289
+ """
290
+ This reconfigures the ProcessGroup to use a new store, rank and world size.
291
+
292
+ Every time this is called it must be provided with a unique prefixed
293
+ store address. I.e. localhost:1234/my/prefix/1
294
+
295
+ This function will block until the underlying ProcessGroup is created.
296
+ If an error occurs this will throw.
297
+
298
+ Args:
299
+ store_addr: address of the store to use
300
+ replica_id: the replica_id for this group
301
+ rank: rank of this process
302
+ world_size: world size of this process group
303
+ quorum_id: current quorum's identifier
304
+ group_rank: local rank within the replica group
305
+ group_world_size: the number of ranks within a replica
306
+ global_ranks: the global ranks part of this process group
307
+ """
308
+ raise NotImplementedError("not implemented")
309
+
310
+ def size(self) -> int:
311
+ raise NotImplementedError("not implemented")
312
+
313
+ def getBackendName(self) -> str:
314
+ raise NotImplementedError("not implemented")
315
+
316
+ def _register(self, name: str) -> str:
317
+ group_name = f"{self.getBackendName()}:{name}"
318
+
319
+ # This is needed for DeviceMesh and functional collectives to work.
320
+ # Resizable worlds don't fit well into DeviceMesh so we register a world
321
+ # size 1 PG.
322
+
323
+ def create_pg(
324
+ prefix_store: PrefixStore, rank: int, world_size: int, timeout: float
325
+ ) -> ProcessGroup:
326
+ return self
327
+
328
+ devices = ["cpu"]
329
+ if torch.cuda.is_available():
330
+ devices.append("cuda")
331
+ elif torch.xpu.is_available():
332
+ devices.append("xpu")
333
+ dist.Backend.register_backend(group_name, create_pg, devices=devices)
334
+
335
+ return group_name
336
+
337
+ def register(self, name: str) -> "ProcessGroup":
338
+ """
339
+ Registers the process group with the global registry. This enables usage
340
+ with things like functional_collectives which are compilable.
341
+
342
+ This should only be called once.
343
+
344
+ Args:
345
+ name: name must be a unique name for this process group
346
+ """
347
+
348
+ group_name = self._register(name)
349
+
350
+ return dist.new_group(
351
+ ranks=[dist.get_rank()],
352
+ backend=group_name,
353
+ group_desc=group_name,
354
+ timeout=timedelta(seconds=60.0), # this timeout isn't used
355
+ )
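A sketch of how register() is typically combined with functional collectives (it assumes a default process group is already initialized in this process, since dist.get_rank() and dist.new_group() are called under the hood; names are placeholders):

    import torch
    import torch.distributed._functional_collectives as ft_c

    pg = ProcessGroupGloo()
    pg.configure("localhost:29500/torchft/quorum_0", "replica_0", rank=0, world_size=2)

    # Register once; the returned world-size-1 group is accepted by
    # functional collectives and DeviceMesh.
    registered = pg.register("train_pg")

    t = torch.ones(4)
    t = ft_c.all_reduce(t, "sum", registered)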
356
+
357
+ @property
358
+ def group_name(self) -> str:
359
+ if self._group_name is None:
360
+ raise ValueError("ProcessGroup name not set")
361
+ return self._group_name
362
+
363
+ def _set_group_name(self, name: str) -> None:
364
+ self._group_name = name
365
+
366
+ def unregister(self) -> None:
367
+ """
368
+ Unregisters the process group with the global registry.
369
+
370
+ Must be registered first.
371
+ """
372
+ dist.destroy_process_group(self)
373
+
374
+ def abort(self) -> None:
375
+ """
376
+ Aborts the process group.
377
+ """
378
+ pass
379
+
380
+ def shutdown(self) -> None:
381
+ """
382
+ Shuts down the process group.
383
+ """
384
+ pass
385
+
386
+ def errored(self) -> Optional[Exception]:
387
+ """
388
+ Whether an async error occurred that requires reconfiguration.
389
+ """
390
+ return None
391
+
392
+ def set_timeout(self, timeout: timedelta) -> None:
393
+ """
394
+ Sets the default timeout for the process group.
395
+ """
396
+ raise NotImplementedError("set_timeout not implemented")
397
+
398
+ def __repr__(self) -> str:
399
+ return f"{self.__class__.__name__}()"
400
+
401
+
402
+ class ProcessGroupWrapper(ProcessGroup):
403
+ """
404
+ This is a wrapper around any ProcessGroup with a reconfiguration method.
405
+
406
+ Args:
407
+ timeout: timeout for reconfiguration for TCPStore
408
+ pg: optional ProcessGroup to use, if None a new one will be created
409
+ """
410
+
411
+ def __init__(
412
+ self,
413
+ timeout: timedelta = timedelta(seconds=60),
414
+ pg: Optional[ProcessGroup] = None,
415
+ ) -> None:
416
+ super().__init__(0, 1)
417
+ self._pg: Optional[BaseProcessGroup] = pg
418
+ self._timeout = timeout
419
+ self._replica_id: str | None = None
420
+ self._rank: int | None = None
421
+ self._quorum_id: int | None = None
422
+ self._group_rank: int | None = None
423
+ self._group_world_size: int | None = None
424
+ self._global_ranks: list[int] | None = None
425
+
426
+ self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
427
+
428
+ def getBackendName(self) -> str:
429
+ pg = self._pg
430
+ if isinstance(pg, ProcessGroup):
431
+ return pg.getBackendName()
432
+
433
+ raise NotImplementedError("not implemented")
434
+
435
+ def configure(
436
+ self,
437
+ store_addr: str,
438
+ replica_id: str,
439
+ rank: int,
440
+ world_size: int,
441
+ quorum_id: Optional[int] = None,
442
+ group_rank: Optional[int] = None,
443
+ group_world_size: Optional[int] = None,
444
+ global_ranks: Optional[list[int]] = None,
445
+ ) -> None:
446
+ pg = self._pg
447
+ self._replica_id = replica_id
448
+ self._quorum_id = quorum_id
449
+ self._group_rank = group_rank
450
+ self._group_world_size = group_world_size
451
+ self._rank = rank
452
+ self._global_ranks = global_ranks
453
+ if isinstance(pg, ProcessGroup):
454
+ pg.configure(
455
+ store_addr,
456
+ replica_id,
457
+ rank,
458
+ world_size,
459
+ quorum_id,
460
+ group_rank,
461
+ group_world_size,
462
+ global_ranks,
463
+ )
464
+ return
465
+
466
+ # abort if already initialized
467
+ self.abort(errored=False)
468
+
469
+ store = create_store_client(store_addr, timeout=self._timeout)
470
+
471
+ self._pg = self._create_pg(store, rank, world_size)
472
+
473
+ def abort(self, errored: bool = True) -> None:
474
+ if errored:
475
+ self.errors_logger.info(
476
+ "",
477
+ extra={
478
+ "job_id": os.environ.get("JOB_ID", "unknown"),
479
+ "replica_id": self._replica_id,
480
+ "rank": self._rank,
481
+ "quorum_id": self._quorum_id,
482
+ "error": "process_group_abort",
483
+ },
484
+ )
485
+ pg = self._pg
486
+ if pg is not None:
487
+ if hasattr(pg, "abort"):
488
+ pg.abort()
489
+ else:
490
+ backend = None
491
+ try:
492
+ if torch.cuda.is_available():
493
+ backend = pg._get_backend(torch.device("cuda"))
494
+ elif torch.xpu.is_available():
495
+ backend = pg._get_backend(torch.device("xpu"))
496
+ except RuntimeError:
497
+ backend = None
498
+ if backend is not None and hasattr(backend, "abort"):
499
+ backend.abort()
500
+
501
+ self._pg = None
502
+
503
+ def shutdown(self) -> None:
504
+ # TODO: abort PG if possible
505
+ self._pg = None
506
+
507
+ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
508
+ raise NotImplementedError("not implemented")
509
+
510
+ def _wrap_work(self, work: Work, opts: object) -> Work:
511
+ return work
512
+
513
+ def _opts_hook(self, opts: T) -> T:
514
+ return opts
515
+
516
+ @contextmanager
517
+ def _run_context(self) -> Generator[None, None, None]:
518
+ yield
519
+
520
+ def set_timeout(self, timeout: timedelta) -> None:
521
+ self._timeout = timeout
522
+
523
+ def allgather(
524
+ self,
525
+ output_tensors: List[List[torch.Tensor]],
526
+ input_tensor: List[torch.Tensor],
527
+ opts: AllgatherOptions,
528
+ ) -> Work:
529
+ with self._run_context():
530
+ return self._wrap_work(
531
+ self.parent.allgather(
532
+ output_tensors, input_tensor, self._opts_hook(opts)
533
+ ),
534
+ opts,
535
+ )
536
+
537
+ def allgather_into_tensor_coalesced(
538
+ self,
539
+ output_tensors: List[torch.Tensor],
540
+ input_tensors: List[torch.Tensor],
541
+ opts: AllgatherOptions,
542
+ ) -> Work:
543
+ with self._run_context():
544
+ return self._wrap_work(
545
+ self.parent.allgather_into_tensor_coalesced(
546
+ output_tensors, input_tensors, self._opts_hook(opts)
547
+ ),
548
+ opts,
549
+ )
550
+
551
+ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
552
+ with self._run_context():
553
+ return self._wrap_work(
554
+ self.parent.allreduce(tensors, self._opts_hook(opts)), opts
555
+ )
556
+
557
+ def allreduce_coalesced(
558
+ self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
559
+ ) -> Work:
560
+ with self._run_context():
561
+ return self._wrap_work(
562
+ self.parent.allreduce_coalesced(tensors, self._opts_hook(opts)), opts
563
+ )
564
+
565
+ def alltoall_base(
566
+ self,
567
+ output_buffer: torch.Tensor,
568
+ input_buffer: torch.Tensor,
569
+ output_split_sizes: List[int],
570
+ input_split_sizes: List[int],
571
+ opts: AllToAllOptions,
572
+ ) -> Work:
573
+ with self._run_context():
574
+ return self._wrap_work(
575
+ self.parent.alltoall_base(
576
+ output_buffer,
577
+ input_buffer,
578
+ output_split_sizes,
579
+ input_split_sizes,
580
+ self._opts_hook(opts),
581
+ ),
582
+ opts,
583
+ )
584
+
585
+ def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
586
+ with self._run_context():
587
+ return self._wrap_work(self.parent.barrier(self._opts_hook(opts)), opts)
588
+
589
+ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
590
+ with self._run_context():
591
+ return self._wrap_work(
592
+ self.parent.broadcast(tensor_list, self._opts_hook(opts)), opts
593
+ )
594
+
595
+ def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
596
+ with self._run_context():
597
+ return self._wrap_work(self.parent.recv(tensors, src_rank, tag), None)
598
+
599
+ def reduce_scatter(
600
+ self,
601
+ output_tensors: List[torch.Tensor],
602
+ input_tensors: List[List[torch.Tensor]],
603
+ opts: object,
604
+ ) -> Work:
605
+ with self._run_context():
606
+ return self._wrap_work(
607
+ self.parent.reduce_scatter(
608
+ output_tensors, input_tensors, self._opts_hook(opts)
609
+ ),
610
+ opts,
611
+ )
612
+
613
+ def reduce_scatter_tensor_coalesced(
614
+ self,
615
+ output_tensors: List[torch.Tensor],
616
+ input_tensors: List[torch.Tensor],
617
+ opts: ReduceScatterOptions,
618
+ ) -> Work:
619
+ with self._run_context():
620
+ return self._wrap_work(
621
+ self.parent.reduce_scatter_tensor_coalesced(
622
+ output_tensors, input_tensors, self._opts_hook(opts)
623
+ ),
624
+ opts,
625
+ )
626
+
627
+ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
628
+ with self._run_context():
629
+ return self._wrap_work(self.parent.send(tensors, dst_rank, tag), None)
630
+
631
+ def size(self) -> int:
632
+ return self.parent.size()
633
+
634
+ @property
635
+ def parent(self) -> BaseProcessGroup:
636
+ assert self._pg is not None, "process group not initialized"
637
+ return self._pg
638
+
639
+ def __repr__(self) -> str:
640
+ return f"{self.__class__.__name__}(pg={self._pg})"
641
+
642
+
643
+ class ProcessGroupGloo(ProcessGroupWrapper):
644
+ """
645
+ This is a reconfigurable version of ProcessGroupGloo.
646
+ """
647
+
648
+ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
649
+ pg = BaseProcessGroup(store, rank, world_size)
650
+ pg._set_default_backend(ProcessGroup.BackendType.GLOO)
651
+ # pyre-fixme[16]: no attribute ProcessGroupGloo
652
+ backend_class = BaseProcessGroupGloo(store, rank, world_size, self._timeout)
653
+ backend_class._set_sequence_number_for_group()
654
+
655
+ if self._global_ranks:
656
+ backend_class.options.global_ranks_in_group = self._global_ranks
657
+ if self._group_rank and self._group_world_size:
658
+ backend_class.options.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
659
+
660
+ pg._register_backend(
661
+ torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
662
+ )
663
+ if torch.cuda.is_available():
664
+ pg._register_backend(
665
+ torch.device("cuda"), ProcessGroup.BackendType.GLOO, backend_class
666
+ )
667
+ return pg
668
+
669
+ def getBackendName(self) -> str:
670
+ return "torchft-gloo"
671
+
672
+ # pyre-fixme[14,15]: inconsistent override
673
+ def reduce_scatter(
674
+ self,
675
+ output_tensors: List[torch.Tensor],
676
+ input_tensors: List[List[torch.Tensor]],
677
+ opts: ReduceScatterOptions,
678
+ ) -> None:
679
+ """
680
+ This function is a placeholder for the reduce_scatter operation in the
681
+ ProcessGroupGloo class. However, this operation is not supported by the
682
+ Gloo backend, and thus, calling this function will raise a
683
+ RuntimeError.
684
+
685
+ Raises:
686
+ RuntimeError: Always raised since reduce_scatter is not
687
+ supported by ProcessGroupGloo.
688
+ """
689
+ raise RuntimeError("ProcessGroupGloo does not support reduce_scatter.")
690
+
691
+ # pyre-fixme[15]: inconsistent override
692
+ def reduce_scatter_tensor_coalesced(
693
+ self,
694
+ output_tensors: List[torch.Tensor],
695
+ input_tensors: List[torch.Tensor],
696
+ opts: ReduceScatterOptions,
697
+ ) -> None:
698
+ """
699
+ This function is a placeholder for the reduce_scatter_tensor_coalesced
700
+ operation in the ProcessGroupGloo class.
701
+ However, this operation is not supported by the
702
+ Gloo backend, and thus, calling this function will raise a
703
+ RuntimeError.
704
+
705
+ Raises:
706
+ RuntimeError: Always raised since reduce_scatter_tensor_coalesced is not
707
+ supported by ProcessGroupGloo.
708
+ """
709
+ raise RuntimeError(
710
+ "ProcessGroupGloo does not support reduce_scatter_tensor_coalesced."
711
+ )
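A sketch of the reconfiguration cycle for this Gloo wrapper across two successive quorums (host, port, and prefix naming are illustrative):

    from datetime import timedelta

    pg = ProcessGroupGloo(timeout=timedelta(seconds=30))

    for quorum_id, world_size in [(0, 4), (1, 3)]:
        # A fresh, unique store prefix per configure() call avoids reusing
        # stale keys from the previous quorum.
        pg.configure(
            store_addr=f"localhost:29500/torchft/quorum_{quorum_id}",
            replica_id="replica_0",
            rank=0,
            world_size=world_size,
            quorum_id=quorum_id,
        )
        # ... run collectives for this quorum ...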
712
+
713
+
714
+ class _WorkAcceleratorTimeout(Work):
715
+ def __init__(self, pg: ProcessGroup, work: Work, timeout: timedelta) -> None:
716
+ super().__init__()
717
+ self._pg = pg
718
+ self._work = work
719
+ self._timeout = timeout
720
+
721
+ def wait(self, timeout: Optional[timedelta] = None) -> bool:
722
+ async_timeout = timeout or self._timeout
723
+ with self._stream_timeout(self._pg, async_timeout):
724
+ # In newer versions of PyTorch work may not exist if the call was
725
+ # not async. In these cases we can just schedule the stream timeout
726
+ # and return.
727
+ if self._work is not None:
728
+ if not self._work.wait():
729
+ return False
730
+
731
+ # Always use cuda stream for timeout to avoid ProcessGroupNCCL
732
+ # watchdog firing and crashing the process.
733
+ if timeout is not None:
734
+ torch.cuda.synchronize()
735
+
736
+ return True
737
+
738
+ @classmethod
739
+ @contextmanager
740
+ def _stream_timeout(
741
+ cls, pg: ProcessGroup, timeout: timedelta
742
+ ) -> Generator[None, None, None]:
743
+ """
744
+ Set a timeout on the CUDA stream for the given process group.
745
+
746
+ This does not hold a reference to self to avoid holding the work
747
+ object/tensors longer than necessary.
748
+
749
+ Args:
750
+ pg: The process group to call abort on.
751
+ timeout: The timeout to set on the CUDA stream.
752
+ """
753
+
754
+ def callback() -> None:
755
+ logger.error(f"aborting after {timeout}!")
756
+ pg.abort()
757
+
758
+ # make sure .wait() can be cancelled if it blocks i.e. in barrier
759
+ with context_timeout(callback, timeout):
760
+ yield
761
+
762
+ # Cancel work if the cuda stream doesn't complete
763
+ stream_timeout(callback, timeout)
764
+
765
+ def get_future(self) -> torch.futures.Future[object]:
766
+ fut = self._work.get_future()
767
+
768
+ def done_callback(fut: torch.futures.Future[object]) -> None:
769
+ try:
770
+ with self._stream_timeout(self._pg, self._timeout):
771
+ fut.wait()
772
+
773
+ except Exception as e:
774
+ logger.error(f"done callback failed: {e}")
775
+
776
+ fut.add_done_callback(done_callback)
777
+ return fut
778
+
779
+
780
+ class ProcessGroupNCCL(ProcessGroupWrapper):
781
+ """
782
+ This is a reconfigurable version of ProcessGroupNCCL.
783
+
784
+ If you are using a supported version of NCCL (NCCL >= 2.26, torch >= 2.7)
785
+ this will attempt to use ncclCommAbort to recover from any timeouts.
786
+
787
+ This uses a Python user space event loop to asynchronously wait for the NCCL
788
+ operations to complete. This should not be used with very long timeouts as
789
+ the timeout entries are not cleaned up until the elapsed duration completes
790
+ which may result in slowness or excess memory usage.
791
+
792
+ WARNING: this may result in deadlocks due to NCCL error handling and on old
793
+ versions of torch/NCCL will result in deadlocks.
794
+
795
+ Args:
796
+ timeout: the timeout to use for NCCL operations.
797
+ """
798
+
799
+ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
800
+ super().__init__(timeout)
801
+ self._use_abort: bool = torch.cuda.nccl.version() >= (2, 25)
802
+
803
+ self._errored: Optional[Exception] = None
804
+
805
+ NONBLOCKING_TIMEOUT_ENV = "TORCH_NCCL_NONBLOCKING_TIMEOUT"
806
+ if NONBLOCKING_TIMEOUT_ENV not in os.environ:
807
+ warnings.warn(
808
+ f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
809
+ "If any nonblocking NCCL operations have already run this may "
810
+ "result in the default timeout of 30 minutes and hangs on error."
811
+ )
812
+ os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())
813
+
814
+ def _opts_hook(self, opts: T) -> T:
815
+ if not self._use_abort:
816
+ return opts
817
+
818
+ # We need to clear the timeout to apply our own timeout that doesn't
819
+ # crash the whole program.
820
+ if hasattr(opts, "timeout"):
821
+ # apply default timeout to disable
822
+ opts.timeout = AllgatherOptions().timeout
823
+ return opts
824
+
825
+ def _wrap_work(self, work: Work, opts: object) -> Work:
826
+ if not self._use_abort:
827
+ return work
828
+
829
+ timeout = self._timeout
830
+ # pyre-fixme[16]: no attribute timeout
831
+ if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0:
832
+ timeout = opts.timeout
833
+ return _WorkAcceleratorTimeout(self, work, timeout)
834
+
835
+ @contextmanager
836
+ def _run_context(self) -> Generator[None, None, None]:
837
+ timeout: timedelta = self._timeout
838
+
839
+ def callback() -> None:
840
+ logger.error(f"aborting after {timeout}!")
841
+ self.abort()
842
+
843
+ # when running in blocking mode we need to make sure collectives can
844
+ # timeout
845
+ with context_timeout(callback, timeout):
846
+ yield
847
+
848
+ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
849
+ # pyre-fixme[21]: no attribute ProcessGroupNCCL
850
+ from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
851
+
852
+ self._errored = None
853
+
854
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
855
+ opts = BaseProcessGroupNCCL.Options()
856
+ opts.config.blocking = False
857
+ if self._global_ranks:
858
+ opts.global_ranks_in_group = self._global_ranks
859
+ if self._group_rank and self._group_world_size:
860
+ opts.group_name = f"torchft_quorum_{self._quorum_id}_rank_{self._group_rank % self._group_world_size}"
861
+
862
+ pg = BaseProcessGroup(store, rank, world_size)
863
+ pg._set_default_backend(ProcessGroup.BackendType.NCCL)
864
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
865
+ backend_class = BaseProcessGroupNCCL(store, rank, world_size, opts)
866
+ backend_class._set_sequence_number_for_group()
867
+ backend_class.eager_connect_single_device(
868
+ torch.device(torch.accelerator.current_device_index())
869
+ )
870
+ pg._register_backend(
871
+ torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
872
+ )
873
+ return pg
874
+
875
+ def abort(self, errored: bool = True) -> None:
876
+ # We need to set the error before aborting to ensure that errored()
877
+ # returns the error correctly when NCCL abort fires and unblocks the
878
+ # stream.
879
+ if os.environ.get("TORCHFT_TRIGGER_FR_ON_ABORT", "false") == "true":
880
+ trigger_nccl_fr_trace_through_pipe(dist.get_rank())
881
+ self._errored = RuntimeError("aborted")
882
+
883
+ super().abort(errored=errored)
884
+
885
+ def errored(self) -> Optional[Exception]:
886
+ # force a synchronization to ensure all work is complete
887
+ synchronize()
888
+ return self._errored
889
+
890
+ def getBackendName(self) -> str:
891
+ return "torchft-nccl"
892
+
893
+
894
+ class ProcessGroupXCCL(ProcessGroupWrapper):
895
+ """
896
+ This is a reconfigurable version of ProcessGroupXCCL for Intel XPU devices.
897
+
898
+ This process group is designed to work with Intel XPU devices using XCCL
899
+ (eXtended Collective Communication Library). It provides similar functionality
900
+ to ProcessGroupNCCL but optimized for Intel XPU architecture.
901
+
902
+ If you are using a supported version of XCCL, this will attempt to use
903
+ xccl abort mechanisms to recover from any timeouts.
904
+
905
+ This uses a Python user space event loop to asynchronously wait for the XCCL
906
+ operations to complete. This should not be used with very long timeouts as
907
+ the timeout entries are not cleaned up until the elapsed duration completes
908
+ which may result in slowness or excess memory usage.
909
+
910
+ Args:
911
+ timeout: the timeout to use for XCCL operations.
912
+ """
913
+
914
+ def __init__(self, timeout: timedelta = timedelta(seconds=60.0)) -> None:
915
+ super().__init__(timeout)
916
+ # Check if XPU is available and XCCL is supported
917
+ self._use_abort: bool = torch.xpu.is_available()
918
+
919
+ self._errored: Optional[Exception] = None
920
+
921
+ NONBLOCKING_TIMEOUT_ENV = "TORCH_XCCL_NONBLOCKING_TIMEOUT"
922
+ if NONBLOCKING_TIMEOUT_ENV not in os.environ:
923
+ warnings.warn(
924
+ f"{NONBLOCKING_TIMEOUT_ENV} is not set, defaulting to {timeout}. "
925
+ "If any nonblocking XCCL operations have already run this may "
926
+ "result in the default timeout of 30 minutes and hangs on error."
927
+ )
928
+ os.environ[NONBLOCKING_TIMEOUT_ENV] = str(timeout.total_seconds())
929
+
930
+ def _opts_hook(self, opts: T) -> T:
931
+ if not self._use_abort:
932
+ return opts
933
+
934
+ # We need to clear the timeout to apply our own timeout that doesn't
935
+ # crash the whole program.
936
+ if hasattr(opts, "timeout"):
937
+ # apply default timeout to disable
938
+ opts.timeout = AllgatherOptions().timeout
939
+ return opts
940
+
941
+ def _wrap_work(self, work: Work, opts: object) -> Work:
942
+ if not self._use_abort:
943
+ return work
944
+
945
+ timeout = self._timeout
946
+ # pyre-fixme[16]: no attribute timeout
947
+ if hasattr(opts, "timeout") and opts.timeout.total_seconds() > 0:
948
+ timeout = opts.timeout
949
+ return _WorkAcceleratorTimeout(self, work, timeout)
950
+
951
+ @contextmanager
952
+ def _run_context(self) -> Generator[None, None, None]:
953
+ timeout: timedelta = self._timeout
954
+
955
+ def callback() -> None:
956
+ logger.error(f"aborting after {timeout}!")
957
+ self.abort()
958
+
959
+ # when running in blocking mode we need to make sure collectives can
960
+ # timeout
961
+ with context_timeout(callback, timeout):
962
+ yield
963
+
964
+ def _create_pg(self, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
965
+ # pyre-fixme[21]: no attribute ProcessGroupXCCL
966
+ from torch.distributed import ProcessGroupXCCL as BaseProcessGroupXCCL
967
+
968
+ self._errored = None
969
+
970
+ # pyre-fixme[16]: no attribute ProcessGroupXCCL
971
+ opts = BaseProcessGroupXCCL.Options()
972
+ # opts.config.blocking = False
973
+
974
+ pg = BaseProcessGroup(store, rank, world_size)
975
+ pg._set_default_backend(ProcessGroup.BackendType.XCCL)
976
+ # pyre-fixme[16]: no attribute ProcessGroupXCCL
977
+ backend_class = BaseProcessGroupXCCL(store, rank, world_size, opts)
978
+ backend_class._set_sequence_number_for_group()
979
+ backend_class.eager_connect_single_device(
980
+ torch.device(torch.accelerator.current_device_index())
981
+ )
982
+ pg._register_backend(
983
+ torch.device("xpu"), ProcessGroup.BackendType.XCCL, backend_class
984
+ )
985
+ return pg
986
+
987
+ def abort(self, errored: bool = True) -> None:
988
+ # We need to set the error before aborting to ensure that errored()
989
+ # returns the error correctly when XCCL abort fires and unblocks the
990
+ # stream.
991
+ self._errored = RuntimeError("aborted")
992
+
993
+ super().abort(errored)
994
+
995
+ def errored(self) -> Optional[Exception]:
996
+ # force a synchronization to ensure all work is complete
997
+ torch.xpu.current_stream().synchronize()
998
+
999
+ return self._errored
1000
+
1001
+ def getBackendName(self) -> str:
1002
+ return "torchft-xccl"
1003
+
1004
+
1005
+ class ProcessGroupDummy(ProcessGroup):
1006
+ """
1007
+ This process group discards all data passed to it and returns success. This
1008
+ is intended for rare cases where we want to discard certain operations
1009
+ without modifying the underlying library.
1010
+
1011
+ This PG only supports world_size of 1.
1012
+ """
1013
+
1014
+ def __init__(self, rank: int, world: int) -> None:
1015
+ super().__init__(rank, world)
1016
+ assert rank == 0
1017
+ assert world == 1
1018
+
1019
+ self._rank = rank
1020
+ self._world = world
1021
+ self.wait_count = 0
1022
+ self.get_future_count = 0
1023
+ self._work: List[Work] = []
1024
+ self.configure_count = 0
1025
+
1026
+ def configure(
1027
+ self,
1028
+ store_addr: str,
1029
+ replica_id: str,
1030
+ rank: int,
1031
+ world_size: int,
1032
+ quorum_id: Optional[int] = None,
1033
+ group_rank: Optional[int] = None,
1034
+ group_world_size: Optional[int] = None,
1035
+ global_ranks: Optional[list[int]] = None,
1036
+ ) -> None:
1037
+ self.configure_count += 1
1038
+
1039
+ def allgather(
1040
+ self,
1041
+ output_tensors: List[List[torch.Tensor]],
1042
+ input_tensor: List[torch.Tensor],
1043
+ opts: object,
1044
+ ) -> Work:
1045
+ for o, i in zip(output_tensors[0], input_tensor):
1046
+ o.copy_(i)
1047
+
1048
+ res = _DummyWork(output_tensors)
1049
+ self._work.append(res)
1050
+ return res
1051
+
1052
+ def allgather_into_tensor_coalesced(
1053
+ self,
1054
+ output_tensors: List[torch.Tensor],
1055
+ input_tensors: List[torch.Tensor],
1056
+ opts: AllgatherOptions,
1057
+ ) -> Work:
1058
+ for o, i in zip(output_tensors, input_tensors):
1059
+ o.copy_(i)
1060
+
1061
+ res = _DummyWork(output_tensors)
1062
+ self._work.append(res)
1063
+ return res
1064
+
1065
+ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1066
+ res = _DummyWork(tensors)
1067
+ self._work.append(res)
1068
+ return res
1069
+
1070
+ def allreduce_coalesced(
1071
+ self, tensors: List[torch.Tensor], opts: Union[AllreduceOptions, ReduceOp]
1072
+ ) -> Work:
1073
+ res = _DummyWork(tensors)
1074
+ self._work.append(res)
1075
+ return res
1076
+
1077
+ def alltoall_base(
1078
+ self,
1079
+ output_buffer: torch.Tensor,
1080
+ input_buffer: torch.Tensor,
1081
+ output_split_sizes: List[int],
1082
+ input_split_sizes: List[int],
1083
+ opts: AllToAllOptions,
1084
+ ) -> Work:
1085
+ output_buffer.copy_(input_buffer)
1086
+ res = _DummyWork([output_buffer])
1087
+ self._work.append(res)
1088
+ return res
1089
+
1090
+ def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
1091
+ return _DummyWork(None)
1092
+
1093
+ def broadcast(self, tensor_list: List[torch.Tensor], opts: object) -> Work:
1094
+ res = _DummyWork(tensor_list)
1095
+ self._work.append(res)
1096
+ return res
1097
+
1098
+ def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
1099
+ return _DummyWork(None)
1100
+
1101
+ def reduce_scatter(
1102
+ self,
1103
+ output_tensors: List[torch.Tensor],
1104
+ input_tensors: List[List[torch.Tensor]],
1105
+ opts: object,
1106
+ ) -> Work:
1107
+ for o, i in zip(output_tensors, input_tensors[0]):
1108
+ o.copy_(i)
1109
+
1110
+ res = _DummyWork(output_tensors)
1111
+ self._work.append(res)
1112
+ return res
1113
+
1114
+ def reduce_scatter_tensor_coalesced(
1115
+ self,
1116
+ output_tensors: List[torch.Tensor],
1117
+ input_tensors: List[torch.Tensor],
1118
+ opts: ReduceScatterOptions,
1119
+ ) -> Work:
1120
+ for o, i in zip(output_tensors, input_tensors):
1121
+ o.copy_(i)
1122
+
1123
+ res = _DummyWork(output_tensors)
1124
+ self._work.append(res)
1125
+ return res
1126
+
1127
+ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
1128
+ return _DummyWork(None)
1129
+
1130
+ def size(self) -> int:
1131
+ return self._world
1132
+
1133
+ def getBackendName(self) -> str:
1134
+ return "torchft-dummy"
1135
+
1136
+
1137
+ class _ErrorSwallowingWork(Work):
1138
+ def __init__(
1139
+ self,
1140
+ pg: "ErrorSwallowingProcessGroupWrapper",
1141
+ work: Work,
1142
+ default_result: object,
1143
+ ) -> None:
1144
+ super().__init__()
1145
+
1146
+ self._pg = pg
1147
+ self._work = work
1148
+ self._default_result = default_result
1149
+
1150
+ def wait(self, timeout: Optional[timedelta] = None) -> bool:
1151
+ try:
1152
+ self._work.wait()
1153
+ except Exception as e:
1154
+ self._pg.report_error(e)
1155
+
1156
+ return True
1157
+
1158
+ def get_future(self) -> Future[object]:
1159
+ fut = self._work.get_future()
1160
+
1161
+ # schedule error handling as a continuation on the Future
1162
+ def callback(
1163
+ fut: torch.futures.Future[List[torch.Tensor]],
1164
+ ) -> object:
1165
+ try:
1166
+ return fut.value()
1167
+ except Exception as e:
1168
+ logger.exception(f"got exception in future -- skipping remaining: {e}")
1169
+ self._pg.report_error(e)
1170
+ return self._default_result
1171
+
1172
+ fut = fut.then(callback)
1173
+ return fut
1174
+
1175
+
1176
+ class ErrorSwallowingProcessGroupWrapper(ProcessGroupWrapper):
1177
+ """
1178
+ This is a wrapper around any ProcessGroup that will swallow errors and
1179
+ return dummy results on error.
1180
+
1181
+ This is intended to allow handling errors outside of the training loop to
1182
+ avoid having to modify modeling code to support error handling.
1183
+
1184
+ After an error occurs all future operations will be skipped until the
1185
+ process group is reconfigured via ``configure``.
1186
+ """
1187
+
1188
+ def __init__(self, pg: ProcessGroup) -> None:
1189
+ super().__init__(pg=pg)
1190
+
1191
+ self._error: Optional[Exception] = None
1192
+
1193
+ def configure(
1194
+ self,
1195
+ store_addr: str,
1196
+ replica_id: str,
1197
+ rank: int,
1198
+ world_size: int,
1199
+ quorum_id: Optional[int] = None,
1200
+ group_rank: Optional[int] = None,
1201
+ group_world_size: Optional[int] = None,
1202
+ global_ranks: Optional[list[int]] = None,
1203
+ ) -> None:
1204
+ self._error = None
1205
+
1206
+ super().configure(
1207
+ store_addr,
1208
+ replica_id,
1209
+ rank,
1210
+ world_size,
1211
+ quorum_id,
1212
+ group_rank,
1213
+ group_world_size,
1214
+ global_ranks,
1215
+ )
1216
+
1217
+ def report_error(self, e: Exception) -> None:
1218
+ """
1219
+ Report an error to this process group. This will cause all future
1220
+ operations to be skipped until the process group is reconfigured via
1221
+ ``configure``.
1222
+
1223
+ Args:
1224
+ e: exception to report
1225
+ """
1226
+ self._error = e
1227
+
1228
+ def error(self) -> Optional[Exception]:
1229
+ """
1230
+ Returns the error that was reported to this process group.
1231
+
1232
+ Returns:
1233
+ exception that was reported
1234
+ """
1235
+ return self._error
1236
+
1237
+ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1238
+ if self._error is not None:
1239
+ return _DummyWork(tensors)
1240
+
1241
+ try:
1242
+ return _ErrorSwallowingWork(
1243
+ self,
1244
+ super().allreduce(tensors, opts),
1245
+ tensors,
1246
+ )
1247
+ except Exception as e:
1248
+ self.report_error(e)
1249
+ return _DummyWork(tensors)
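A sketch of checking for swallowed errors outside the training step (addresses are placeholders):

    import torch
    from torch.distributed.distributed_c10d import AllreduceOptions

    pg = ErrorSwallowingProcessGroupWrapper(ProcessGroupGloo())
    pg.configure("localhost:29500/torchft/quorum_0", "replica_0", rank=0, world_size=2)

    pg.allreduce([torch.ones(4)], AllreduceOptions()).wait()  # never raises

    if pg.error() is not None:
        # The collective failed and a dummy result was returned; reconfigure
        # before trusting any further results.
        pass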
1250
+
1251
+
1252
+ class FakeProcessGroupWrapper(ProcessGroupWrapper):
1253
+ """
1254
+ This is a wrapper around any ProcessGroup that can be used to inject
1255
+ errors into the process group at various points.
1256
+
1257
+ This is intended to be used for tests so that they can test cases
1258
+ in which process group operations error out.
1259
+ """
1260
+
1261
+ def __init__(self, pg: ProcessGroup) -> None:
1262
+ super().__init__(pg=pg)
1263
+
1264
+ self._future_error: Optional[Exception] = None
1265
+
1266
+ def configure(
1267
+ self,
1268
+ store_addr: str,
1269
+ replica_id: str,
1270
+ rank: int,
1271
+ world_size: int,
1272
+ quorum_id: Optional[int] = None,
1273
+ group_rank: Optional[int] = None,
1274
+ group_world_size: Optional[int] = None,
1275
+ global_ranks: Optional[list[int]] = None,
1276
+ ) -> None:
1277
+ self._future_error = None
1278
+
1279
+ super().configure(
1280
+ store_addr,
1281
+ replica_id,
1282
+ rank,
1283
+ world_size,
1284
+ quorum_id,
1285
+ group_rank,
1286
+ group_world_size,
1287
+ global_ranks,
1288
+ )
1289
+
1290
+ def report_future_error(self, e: Exception) -> None:
1291
+ """
1292
+ Report an error to this process group. This will cause the
1293
+ future attached to the next operation to error out.
1294
+
1295
+ Args:
1296
+ e: exception to report
1297
+ """
1298
+ self._future_error = e
1299
+
1300
+ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1301
+ work = super().allreduce(tensors, opts)
1302
+
1303
+ if self._future_error is None:
1304
+ return work
1305
+
1306
+ fut = work.get_future()
1307
+
1308
+ def callback(
1309
+ fut: torch.futures.Future[List[torch.Tensor]],
1310
+ ) -> List[torch.Tensor]:
1311
+ future_error, self._future_error = self._future_error, None
1312
+ assert future_error is not None
1313
+ raise future_error
1314
+
1315
+ fut = fut.then(callback)
1316
+
1317
+ return work
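A sketch of how a test might use this wrapper to inject a failure into the future chain of the next collective (illustrative; the real tests live in torchft/process_group_test.py):

    import torch
    from torch.distributed.distributed_c10d import AllreduceOptions

    pg = FakeProcessGroupWrapper(ProcessGroupGloo())
    pg.configure("localhost:29500/torchft/test", "replica_0", rank=0, world_size=1)

    # The continuation attached to the next allreduce's future raises this error.
    pg.report_future_error(RuntimeError("injected failure"))
    work = pg.allreduce([torch.ones(2)], AllreduceOptions())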
1318
+
1319
+
1320
+ class ManagedProcessGroup(ProcessGroupWrapper):
1321
+ """
1322
+ This is a wrapper around any ProcessGroup that is managed by a torchft
1323
+ Manager.
1324
+
1325
+ This uses the ProcessGroup that is configured in the Manager. The world size
1326
+ is dynamic and will report the number of active participants in the quorum to
1327
+ the model.
1328
+
1329
+ Any errors will be asynchronously reported to the manager and only successes
1330
+ will be returned to the caller.
1331
+ """
1332
+
1333
+ def __init__(self, manager: "Manager") -> None:
1334
+ super().__init__(pg=manager._pg)
1335
+
1336
+ self._manager = manager
1337
+
1338
+ def allreduce(self, tensors: List[torch.Tensor], opts: object) -> Work:
1339
+ assert len(tensors) == 1
1340
+
1341
+ if isinstance(opts, ReduceOp):
1342
+ return self._manager.allreduce(tensors[0], reduce_op=opts)
1343
+
1344
+ if isinstance(opts, AllreduceOptions):
1345
+ return self._manager.allreduce(tensors[0], reduce_op=opts.reduceOp)
1346
+
1347
+ assert False, "unreachable"
1348
+
1349
+ def size(self) -> int:
1350
+ return self._manager.num_participants()
1351
+
1352
+ def getBackendName(self) -> str:
1353
+ return self._manager._pg.getBackendName()
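A sketch of handing the managed wrapper to model code, with the Manager owning quorum membership (Manager construction is omitted here; see torchft/manager.py):

    import torch
    from torch.distributed.distributed_c10d import AllreduceOptions
    from torchft.manager import Manager

    manager: Manager = ...  # constructed elsewhere with its own ProcessGroup
    pg = ManagedProcessGroup(manager)

    grad = torch.ones(10)
    pg.allreduce([grad], AllreduceOptions()).wait()  # routed through manager.allreduce
    live_world_size = pg.size()                      # number of participants in the quorum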
1354
+
1355
+
1356
+ class _BabyWork(Work):
1357
+ def __init__(
1358
+ self,
1359
+ pg: "ProcessGroupBaby",
1360
+ op_id: int,
1361
+ stream: Optional[torch.Stream],
1362
+ ) -> None:
1363
+ super().__init__()
1364
+
1365
+ self._pg = pg
1366
+ self._op_id = op_id
1367
+ self._stream = stream
1368
+
1369
+ def wait(self, timeout: Optional[timedelta] = None) -> bool:
1370
+ return self._pg._wait(self._op_id, timeout)
1371
+
1372
+ def synchronize(self) -> None:
1373
+ # TODO: No one seems to use this and NCCL wait already only waits the
1374
+ # stream and is non-blocking on the CPU side so no real need for a
1375
+ # separate call.
1376
+ raise NotImplementedError("not implemented")
1377
+
1378
+ def get_future(self) -> Future[object]:
1379
+ return self._pg._get_future(self._op_id, self._stream)
1380
+
1381
+ def __del__(self) -> None:
1382
+ self._pg._del(self._op_id)
1383
+
1384
+
1385
+ def _is_any_cuda(obj: object) -> bool:
1386
+ """
1387
+ Returns true if any of the tensors in the object are CUDA tensors.
1388
+
1389
+ Supports lists, tuples, dicts, and tensors.
1390
+ """
1391
+ return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_cuda, obj)
1392
+
1393
+
1394
+ def _is_any_xpu(obj: object) -> bool:
1395
+ """
1396
+ Returns true if any of the tensors in the object are XPU tensors.
1397
+
1398
+ Supports lists, tuples, dicts, and tensors.
1399
+ """
1400
+ return tree_any(lambda obj: isinstance(obj, torch.Tensor) and obj.is_xpu, obj)
1401
+
1402
+
1403
+ @dataclass
1404
+ class _OpMetadata:
1405
+ work: Work
1406
+ stream: Optional[torch.Stream]
1407
+
1408
+ @contextmanager
1409
+ def set_stream(self) -> Generator[None, None, None]:
1410
+ with get_stream_context(self.stream):
1411
+ yield
1412
+
1413
+
1414
+ @dataclass
1415
+ class _FutureMetadata:
1416
+ future: Future[object]
1417
+ stream: Optional[torch.Stream]
1418
+
1419
+ @contextmanager
1420
+ def set_stream(self) -> Generator[None, None, None]:
1421
+ with get_stream_context(self.stream):
1422
+ yield
1423
+
1424
+
1425
+ def _maybe_share_tensors(
1426
+ tensor: Union[List[List[torch.Tensor]], List[torch.Tensor], torch.Tensor],
1427
+ ) -> None:
1428
+ """Move a tensor / list of tensors to shared memory if not already in shared memory."""
1429
+ if isinstance(tensor, list):
1430
+ for t in tensor:
1431
+ _maybe_share_tensors(t)
1432
+ elif isinstance(tensor, torch.Tensor):
1433
+ if not tensor.is_shared():
1434
+ tensor.share_memory_()
1435
+ else:
1436
+ raise TypeError(f"expected tensor or list but got {type(tensor)}")
1437
+
1438
+
1439
+ def _assert_list(tensors: Union[List[torch.Tensor], List[List[torch.Tensor]]]) -> None:
1440
+ """Assert that the input is a list of tensors or a nested list of tensors."""
1441
+ if not isinstance(tensors, list):
1442
+ raise TypeError(f"expected list but got {type(tensors)}")
1443
+
1444
+
1445
+ class ProcessGroupBaby(ProcessGroup):
1446
+ """
1447
+ This is a process group that runs the underlying process group in a
1448
+ subprocess. Since it's running in a subprocess all tensors need to be in
1449
+ shared memory or will be moved to shared memory. CUDA/XPU tensors are implicitly
1450
+ shareable and don't need any changes.
1451
+ """
1452
+
1453
+ def __init__(self, timeout: Union[float, timedelta] = 60.0) -> None:
1454
+ super().__init__(0, 1)
1455
+
1456
+ self._world_size = -1
1457
+
1458
+ self._p: Optional[mp.Process] = None
1459
+ self._pipe: Optional[_MonitoredPipe] = None
1460
+ self._future_pipe: Optional[_MonitoredPipe] = None
1461
+ self._future_thread: Optional[threading.Thread] = None
1462
+ self._futures: Dict[int, _FutureMetadata] = {}
1463
+ self._futures_lock = threading.Lock()
1464
+
1465
+ self._next_op_id = 0
1466
+
1467
+ if isinstance(timeout, timedelta):
1468
+ timeout = timeout.total_seconds()
1469
+
1470
+ self._timeout: float = timeout
1471
+
1472
+ def shutdown(self) -> None:
1473
+ """
1474
+ Shutdown the process group. This will kill the underlying process and
1475
+ close all queues.
1476
+
1477
+ This is a no-op if the process group is already shutdown.
1478
+
1479
+ ProcessGroup can be reconfigured after shutdown.
1480
+ """
1481
+
1482
+ if self._pipe is not None:
1483
+ self._pipe.close()
1484
+
1485
+ future_pipe = self._future_pipe
1486
+ if future_pipe is not None:
1487
+ # wait for the future thread to exit and then close the queue
1488
+ future_pipe.close()
1489
+
1490
+ future_thread = self._future_thread
1491
+ assert future_thread is not None
1492
+
1493
+ future_thread.join(timeout=10.0)
1494
+ if future_thread.is_alive():
1495
+ raise RuntimeError("future thread did not exit")
1496
+
1497
+ # Kill after closing queues to avoid log spam.
1498
+ if self._p is not None:
1499
+ self._p.kill()
1500
+
1501
+ def configure(
1502
+ self,
1503
+ store_addr: str,
1504
+ replica_id: str,
1505
+ rank: int,
1506
+ world_size: int,
1507
+ quorum_id: Optional[int] = None,
1508
+ group_rank: Optional[int] = None,
1509
+ group_world_size: Optional[int] = None,
1510
+ global_ranks: Optional[list[int]] = None,
1511
+ ) -> None:
1512
+ self._world_size = world_size
1513
+
1514
+ self.shutdown()
1515
+
1516
+ ctx = mp.get_context("spawn")
1517
+ req_local, req_remote = ctx.Pipe()
1518
+ future_local, future_remote = ctx.Pipe()
1519
+
1520
+ self._pipe = req_local = _MonitoredPipe(req_local)
1521
+ self._future_pipe = future_local = _MonitoredPipe(future_local)
1522
+
1523
+ curr_device = (
1524
+ torch.accelerator.current_device_index()
1525
+ if torch.accelerator.is_available()
1526
+ else -1
1527
+ )
1528
+
1529
+ self._p = p = ctx.Process(
1530
+ target=self._worker,
1531
+ args=(
1532
+ store_addr,
1533
+ rank,
1534
+ world_size,
1535
+ req_remote,
1536
+ future_remote,
1537
+ curr_device,
1538
+ ),
1539
+ daemon=True,
1540
+ )
1541
+ p.start()
1542
+
1543
+ # futures need thread to fire callbacks
1544
+ # this lock needs to be held when manipulating _futures
1545
+ self._futures_lock = threading.Lock()
1546
+ self._futures = {}
1547
+ self._future_thread = threading.Thread(
1548
+ target=self._future_handler,
1549
+ args=(future_local,),
1550
+ daemon=True,
1551
+ )
1552
+ self._future_thread.start()
1553
+
1554
+ # fetch the status of the PG init
1555
+ # if an exception was returned get will throw
1556
+ assert req_local.recv(self._timeout) is None
1557
+
1558
+ @classmethod
1559
+ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
1560
+ """
1561
+ This is a class method to avoid pickling the class.
1562
+ """
1563
+ raise NotImplementedError("not implemented")
1564
+
1565
+ @classmethod
1566
+ def _worker(
1567
+ cls,
1568
+ store_addr: str,
1569
+ rank: int,
1570
+ world_size: int,
1571
+ req_pipe: "Connection[object, object]", # type: ignore
1572
+ future_pipe: "Connection[object, object]", # type: ignore
1573
+ curr_device: int,
1574
+ ) -> None:
1575
+ try:
1576
+ if curr_device >= 0 and torch.accelerator.is_available():
1577
+ torch.accelerator.set_device_index(curr_device)
1578
+
1579
+ store = create_store_client(
1580
+ store_addr,
1581
+ # default TCPStore timeout is 5 minutes
1582
+ timeout=timedelta(minutes=5),
1583
+ )
1584
+
1585
+ try:
1586
+ pg = cls._create_pg(store, rank, world_size)
1587
+ except Exception as e:
1588
+ logger.exception(f"got exception in worker: {e}")
1589
+ req_pipe.send(e)
1590
+ return
1591
+ req_pipe.send(None)
1592
+
1593
+ streams: Dict[str, torch.Stream] = {}
1594
+ work: Dict[int, _OpMetadata] = {}
1595
+
1596
+ while True:
1597
+ op = cast(list[object], req_pipe.recv())
1598
+ cmd = op[0]
1599
+ if cmd == "func":
1600
+ op_id: int
1601
+ op_id, func_name, args, kwargs, stream_device, stream_id, event = (
1602
+ cast(
1603
+ Tuple[
1604
+ int,
1605
+ str,
1606
+ list[object],
1607
+ dict[str, object],
1608
+ int,
1609
+ int,
1610
+ Optional[Union[torch.cuda.Event, torch.xpu.Event]],
1611
+ ],
1612
+ op[1:],
1613
+ )
1614
+ )
1615
+
1616
+ # To avoid potential deadlocks we need to preserve the
1617
+ # stream/synchronization behavior of the parent process.
1618
+ # We allocate one Stream per stream_id to make sure that we
1619
+ # don't accidentally introduce cross stream synchronization
1620
+ # points.
1621
+ if stream_id is not None:
1622
+ stream_key = f"{stream_device}/{stream_id}"
1623
+ if stream_key not in streams:
1624
+ streams[stream_key] = torch.Stream(device=stream_device)
1625
+ stream = streams[stream_key]
1626
+ else:
1627
+ stream = None
1628
+
1629
+ with get_stream_context(stream):
1630
+ # Make the stream wait on the cuda event to make sure we
1631
+ # don't start the operation until the tensor is ready.
1632
+ if event is not None:
1633
+ event.wait()
1634
+
1635
+ args = _PickleSafeOptions.unsafe_args(args)
1636
+ fn = getattr(pg, func_name)
1637
+
1638
+ work[op_id] = _OpMetadata(
1639
+ work=fn(*args, **kwargs),
1640
+ stream=stream,
1641
+ )
1642
+
1643
+ elif cmd == "wait":
1644
+ op_id, timeout = cast(tuple[int, timedelta], op[1:])
1645
+
1646
+ metadata = work[op_id]
1647
+
1648
+ with metadata.set_stream():
1649
+ # With WorkNCCL this makes the stream wait not the CPU when
1650
+ # no timeout is passed.
1651
+ if timeout is not None:
1652
+ metadata.work.wait(timeout)
1653
+ else:
1654
+ metadata.work.wait()
1655
+
1656
+ # Register event on the stream that we can pass to the main
1657
+ # process.
1658
+ event = record_event() if metadata.stream is not None else None
1659
+
1660
+ req_pipe.send((op_id, event))
1661
+ elif cmd == "del":
1662
+ op_id: int = cast(int, op[1])
1663
+ del work[op_id]
1664
+ elif cmd == "future":
1665
+ op_id: int = cast(int, op[1])
1666
+ metadata: _OpMetadata = work[op_id]
1667
+
1668
+ def callback(fut: Future[object], metadata: _OpMetadata) -> None:
1669
+ try:
1670
+ # create an event after the collective has been issued
1671
+ # to wait on this before we call "future"
1672
+ with metadata.set_stream():
1673
+ fut.wait()
1674
+ event = (
1675
+ record_event()
1676
+ if metadata.stream is not None
1677
+ else None
1678
+ )
1679
+
1680
+ future_pipe.send((op_id, _FUTURE_RESULT, None, event))
1681
+ except Exception as e:
1682
+ future_pipe.send((op_id, _FUTURE_EXCEPTION, e, None))
1683
+
1684
+ metadata.work.get_future().add_done_callback(
1685
+ lambda fut: callback(fut, metadata)
1686
+ )
1687
+ elif cmd == "num_active_work":
1688
+ req_pipe.send(len(work))
1689
+ else:
1690
+ raise ValueError(f"unknown cmd: {cmd}")
1691
+
1692
+ except Exception as e:
1693
+ logger.exception(f"worker errored: {e}")
1694
+ req_pipe.send(e)
1695
+ raise
1696
+
1697
+ def _future_handler(self, future_pipe: _MonitoredPipe) -> None:
1698
+ try:
1699
+ while True:
1700
+ try:
1701
+ cmd = future_pipe.recv(timedelta(seconds=10))
1702
+ except TimeoutError:
1703
+ continue
1704
+ except OSError:
1705
+ # subprocess exited
1706
+ break
1707
+
1708
+ op_id, mode, data, event = cast(
1709
+ Tuple[
1710
+ int,
1711
+ str,
1712
+ object,
1713
+ Optional[Union[torch.cuda.Event, torch.xpu.Event]],
1714
+ ],
1715
+ cmd,
1716
+ )
1717
+ with self._futures_lock:
1718
+ meta = self._futures[op_id]
1719
+ del self._futures[op_id]
1720
+ with meta.set_stream():
1721
+ if mode == _FUTURE_RESULT:
1722
+ if event is not None:
1723
+ event.wait()
1724
+ meta.future.set_result(data)
1725
+ elif mode == _FUTURE_EXCEPTION:
1726
+ meta.future.set_exception(data)
1727
+ else:
1728
+ raise ValueError(f"unknown mode {mode}")
1729
+ except Exception as e:
1730
+ logger.exception(f"got unexpected error in future handler: {e}")
1731
+
1732
+ def _get_future(self, op_id: int, stream: Optional[torch.Stream]) -> Future[object]:
1733
+ with self._futures_lock:
1734
+ fut = Future()
1735
+ self._futures[op_id] = _FutureMetadata(future=fut, stream=stream)
1736
+ assert self._pipe is not None
1737
+ self._pipe.send(("future", op_id))
1738
+
1739
+ # TODO: return correct tensor instead of None
1740
+ return fut
1741
+
1742
+ def _wait(self, op_id: int, timeout: Optional[timedelta] = None) -> bool:
1743
+ assert self._pipe is not None
1744
+ self._pipe.send(("wait", op_id, timeout))
1745
+
1746
+ assert self._pipe is not None
1747
+ resp_op_id, event = cast(
1748
+ Tuple[int, Optional[Union[torch.cuda.Event, torch.xpu.Event]]],
1749
+ self._pipe.recv(timeout or self._timeout),
1750
+ )
1751
+ assert resp_op_id == op_id
1752
+ if event is not None:
1753
+ event.wait()
1754
+
1755
+ return True
1756
+
1757
+ def _del(self, op_id: int) -> None:
1758
+ assert self._pipe is not None
1759
+ try:
1760
+ self._pipe.send(("del", op_id))
1761
+ except OSError:
1762
+ # if pipe is closed we can safely do nothing
1763
+ pass
1764
+
1765
+ def _run_func(self, func: str, *args: object, **kwargs: object) -> Work:
1766
+ pipe = self._pipe
1767
+ assert pipe is not None
1768
+
1769
+ is_accelerator = _is_any_cuda(args) or _is_any_xpu(args)
1770
+
1771
+ stream_device = (
1772
+ torch.accelerator.current_stream().device if is_accelerator else None
1773
+ )
1774
+ stream_id = (
1775
+ torch.accelerator.current_stream().stream_id if is_accelerator else None
1776
+ )
1777
+ event = record_event() if is_accelerator else None
1778
+
1779
+ op_id = self._next_op_id
1780
+ self._next_op_id += 1
1781
+
1782
+ pipe.send(
1783
+ (
1784
+ "func",
1785
+ op_id,
1786
+ func,
1787
+ _PickleSafeOptions.safe_args(args),
1788
+ kwargs,
1789
+ stream_device,
1790
+ stream_id,
1791
+ event,
1792
+ ),
1793
+ )
1794
+
1795
+ return _BabyWork(
1796
+ pg=self,
1797
+ op_id=op_id,
1798
+ stream=torch.accelerator.current_stream() if is_accelerator else None,
1799
+ )
1800
+
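The methods above and the _worker loop exchange messages over two pipes. A compact restatement of the message formats as they are constructed in this file (no new behavior, just a summary for readability):

    # request pipe (parent -> worker):
    #   ("func", op_id, func_name, safe_args, kwargs, stream_device, stream_id, event)
    #   ("wait", op_id, timeout)    -> worker replies (op_id, event)
    #   ("del", op_id)              -> no reply; worker drops its Work handle
    #   ("future", op_id)           -> reply arrives on the future pipe
    #   ("num_active_work",)        -> worker replies len(work)
    #
    # future pipe (worker -> parent, consumed by _future_handler):
    #   (op_id, _FUTURE_RESULT, None, event)
    #   (op_id, _FUTURE_EXCEPTION, exception, None)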
1801
+ def allgather(
1802
+ self,
1803
+ output_tensors: List[List[torch.Tensor]],
1804
+ input_tensor: List[torch.Tensor],
1805
+ opts: AllgatherOptions,
1806
+ ) -> Work:
1807
+ _assert_list(output_tensors)
1808
+ _assert_list(input_tensor)
1809
+ _maybe_share_tensors(output_tensors)
1810
+ _maybe_share_tensors(input_tensor)
1811
+ return self._run_func("allgather", output_tensors, input_tensor, opts)
1812
+
1813
+ def allgather_into_tensor_coalesced(
1814
+ self,
1815
+ output_tensors: List[torch.Tensor],
1816
+ input_tensors: List[torch.Tensor],
1817
+ opts: AllgatherOptions,
1818
+ ) -> Work:
1819
+ _assert_list(output_tensors)
1820
+ _assert_list(input_tensors)
1821
+ _maybe_share_tensors(output_tensors)
1822
+ _maybe_share_tensors(input_tensors)
1823
+ return self._run_func(
1824
+ "allgather_into_tensor_coalesced", output_tensors, input_tensors, opts
1825
+ )
1826
+
1827
+ def allreduce(
1828
+ self,
1829
+ tensors: List[torch.Tensor],
1830
+ opts: Union[dist.AllreduceOptions, dist.ReduceOp],
1831
+ ) -> Work:
1832
+ _assert_list(tensors)
1833
+ _maybe_share_tensors(tensors)
1834
+ return self._run_func("allreduce", tensors, opts)
1835
+
1836
+ def allreduce_coalesced(
1837
+ self,
1838
+ tensors: List[torch.Tensor],
1839
+ opts: Union[dist.AllreduceCoalescedOptions, dist.ReduceOp],
1840
+ ) -> Work:
1841
+ _assert_list(tensors)
1842
+ _maybe_share_tensors(tensors)
1843
+ return self._run_func("allreduce_coalesced", tensors, opts)
1844
+
1845
+ def alltoall_base(
1846
+ self,
1847
+ output_buffer: torch.Tensor,
1848
+ input_buffer: torch.Tensor,
1849
+ output_split_sizes: List[int],
1850
+ input_split_sizes: List[int],
1851
+ opts: AllToAllOptions,
1852
+ ) -> Work:
1853
+ _maybe_share_tensors(output_buffer)
1854
+ _maybe_share_tensors(input_buffer)
1855
+ return self._run_func(
1856
+ "alltoall_base",
1857
+ output_buffer,
1858
+ input_buffer,
1859
+ output_split_sizes,
1860
+ input_split_sizes,
1861
+ opts,
1862
+ )
1863
+
1864
+ def barrier(self, opts: Optional[BarrierOptions] = None) -> Work:
1865
+ return self._run_func("barrier", opts)
1866
+
1867
+ def broadcast(
1868
+ self,
1869
+ tensor_list: List[torch.Tensor],
1870
+ opts: BroadcastOptions,
1871
+ ) -> Work:
1872
+ _assert_list(tensor_list)
1873
+ _maybe_share_tensors(tensor_list)
1874
+ return self._run_func("broadcast", tensor_list, opts)
1875
+
1876
+ def recv(self, tensors: List[torch.Tensor], src_rank: int, tag: int) -> Work:
1877
+ _assert_list(tensors)
1878
+ _maybe_share_tensors(tensors)
1879
+ return self._run_func("recv", tensors, src_rank, tag)
1880
+
1881
+ def reduce_scatter(
1882
+ self,
1883
+ output_tensors: List[torch.Tensor],
1884
+ input_tensors: List[List[torch.Tensor]],
1885
+ opts: ReduceScatterOptions,
1886
+ ) -> Work:
1887
+ _assert_list(output_tensors)
1888
+ _assert_list(input_tensors)
1889
+ _maybe_share_tensors(output_tensors)
1890
+ _maybe_share_tensors(input_tensors)
1891
+ return self._run_func("reduce_scatter", output_tensors, input_tensors, opts)
1892
+
1893
+ def reduce_scatter_tensor_coalesced(
1894
+ self,
1895
+ output_tensors: List[torch.Tensor],
1896
+ input_tensors: List[torch.Tensor],
1897
+ opts: ReduceScatterOptions,
1898
+ ) -> Work:
1899
+ _assert_list(output_tensors)
1900
+ _assert_list(input_tensors)
1901
+ _maybe_share_tensors(output_tensors)
1902
+ _maybe_share_tensors(input_tensors)
1903
+ return self._run_func(
1904
+ "reduce_scatter_tensor_coalesced", output_tensors, input_tensors, opts
1905
+ )
1906
+
1907
+ def send(self, tensors: List[torch.Tensor], dst_rank: int, tag: int) -> Work:
1908
+ _assert_list(tensors)
1909
+ _maybe_share_tensors(tensors)
1910
+ return self._run_func("send", tensors, dst_rank, tag)
1911
+
1912
+ def size(self) -> int:
1913
+ return self._world_size
1914
+
1915
+ def num_active_work(self) -> int:
1916
+ assert self._pipe is not None
1917
+ self._pipe.send(("num_active_work",))
1918
+
1919
+ assert self._pipe is not None
1920
+ return cast(int, self._pipe.recv(self._timeout))
1921
+
1922
+ def set_timeout(self, timeout: timedelta) -> None:
1923
+ self._timeout = timeout.total_seconds()
1924
+
1925
+
1926
+ @dataclass
1927
+ class _PickleSafeOptions:
1928
+ func: Callable[[], object]
1929
+ fields: Dict[str, object]
1930
+
1931
+ @classmethod
1932
+ def safe_args(cls, args: T) -> T:
1933
+ if isinstance(args, tuple):
1934
+ return tuple(cls.safe_args(arg) for arg in args)
1935
+ elif isinstance(args, list):
1936
+ return [cls.safe_args(arg) for arg in args]
1937
+ elif isinstance(
1938
+ args,
1939
+ (
1940
+ AllgatherOptions,
1941
+ AllreduceOptions,
1942
+ AllreduceCoalescedOptions,
1943
+ AllToAllOptions,
1944
+ BarrierOptions,
1945
+ BroadcastOptions,
1946
+ ReduceScatterOptions,
1947
+ ),
1948
+ ):
1949
+ return cls.from_torch(args)
1950
+ else:
1951
+ return args
1952
+
1953
+ @classmethod
1954
+ def unsafe_args(cls, args: T) -> T:
1955
+ if isinstance(args, tuple):
1956
+ return tuple(cls.unsafe_args(arg) for arg in args)
1957
+ elif isinstance(args, list):
1958
+ return [cls.unsafe_args(arg) for arg in args]
1959
+ elif isinstance(args, cls):
1960
+ return args.to_torch()
1961
+ else:
1962
+ return args
1963
+
1964
+ @classmethod
1965
+ def from_torch(cls, opts: object) -> "_PickleSafeOptions":
1966
+ return cls(
1967
+ func=opts.__class__,
1968
+ fields={k: getattr(opts, k) for k in dir(opts) if not k.startswith("_")},
1969
+ )
1970
+
1971
+ def to_torch(self) -> object:
1972
+ opts = self.func()
1973
+ for k, v in self.fields.items():
1974
+ setattr(opts, k, v)
1975
+ return opts
1976
+
1977
+
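A minimal sketch of how the conversion above is used when collective arguments cross the pipe; the option class is one already imported by this module, and the tensor/variable names are illustrative only:

    # Illustrative sketch, not part of this file.
    tensor = torch.zeros(4)
    opts = AllreduceOptions()
    safe = _PickleSafeOptions.safe_args(([tensor], opts))
    # `safe` wraps the option object as (constructor, public fields) so the
    # whole tuple can be pickled and sent to the worker, which reverses it:
    restored = _PickleSafeOptions.unsafe_args(safe)
    assert isinstance(restored[1], AllreduceOptions)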
1978
+ class ProcessGroupBabyGloo(ProcessGroupBaby):
1979
+ """
1980
+ This is a ProcessGroup that runs Gloo in a subprocess.
1981
+
1982
+ For most use cases you should prefer ProcessGroupGloo or
1983
+ ProcessGroupBabyNCCL.
1984
+ """
1985
+
1986
+ @classmethod
1987
+ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
1988
+ pg = BaseProcessGroup(store, rank, world_size)
1989
+ pg._set_default_backend(ProcessGroup.BackendType.GLOO)
1990
+ # pyre-fixme[16]: no attribute ProcessGroupGloo
1991
+ backend_class = BaseProcessGroupGloo(store, rank, world_size)
1992
+ pg._register_backend(
1993
+ torch.device("cpu"), ProcessGroup.BackendType.GLOO, backend_class
1994
+ )
1995
+ return pg
1996
+
1997
+ def getBackendName(self) -> str:
1998
+ return "torchft-baby-gloo"
1999
+
2000
+ # pyre-fixme[15]: inconsistent override
2001
+ def reduce_scatter(
2002
+ self,
2003
+ output_tensors: List[torch.Tensor],
2004
+ input_tensors: List[List[torch.Tensor]],
2005
+ opts: ReduceScatterOptions,
2006
+ ) -> None:
2007
+ """
2008
+ This function is a placeholder for the reduce_scatter operation in the
2009
+ ProcessGroupBabyGloo class. This operation is not supported by the
2010
+ Gloo backend, so calling this function will raise a
2011
+ RuntimeError.
2012
+
2013
+ Raises:
2014
+ RuntimeError: Always raised since reduce_scatter is not
2015
+ supported by ProcessGroupBabyGloo.
2016
+ """
2017
+ raise RuntimeError("ProcessGroupBabyGloo does not support reduce_scatter.")
2018
+
2019
+ # pyre-fixme[15]: inconsistent override
2020
+ def reduce_scatter_tensor_coalesced(
2021
+ self,
2022
+ output_tensors: List[torch.Tensor],
2023
+ input_tensors: List[torch.Tensor],
2024
+ opts: ReduceScatterOptions,
2025
+ ) -> None:
2026
+ """
2027
+ This function is a placeholder for the reduce_scatter_tensor_coalesced
2028
+ operation in the ProcessGroupBabyGloo class.
2029
+ This operation is not supported by the
2030
+ Gloo backend, so calling this function will raise a
2031
+ RuntimeError.
2032
+
2033
+ Raises:
2034
+ RuntimeError: Always raised since reduce_scatter_tensor_coalesced is not
2035
+ supported by ProcessGroupBabyGloo.
2036
+ """
2037
+ raise RuntimeError(
2038
+ "ProcessGroupBabyGloo does not support reduce_scatter_tensor_coalesced."
2039
+ )
2040
+
2041
+
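Because the Gloo backend lacks reduce_scatter, CPU callers sometimes emulate it with an allreduce followed by taking the local chunk. A rough, hypothetical sketch of that workaround (not part of this file, and notably less efficient than a native reduce_scatter):

    # Hypothetical helper built only on operations ProcessGroupBabyGloo supports.
    def emulated_reduce_scatter(pg, output, inputs, rank):
        # inputs: one tensor per rank on every rank; output: this rank's result.
        flat = torch.stack(inputs)
        pg.allreduce([flat], AllreduceOptions()).wait()  # default reduce op is SUM
        output.copy_(flat[rank])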
2042
+ class ProcessGroupBabyNCCL(ProcessGroupBaby):
2043
+ """
2044
+ This is a ProcessGroup that runs NCCL in a subprocess.
2045
+
2046
+ For the NCCL backend, extra memory will be used by the subprocess's CUDA
2047
+ context compared to running NCCL in the main process. This is typically
2048
+ around ~1GB.
2049
+
2050
+ The returned Work objects only synchronize on the CUDA stream and not on the
2051
+ CPU side. This works by passing CUDA Events between the processes. To
2052
+ synchronize with the CPU, call torch.cuda.synchronize() after wait().
2053
+
2054
+ WARNING: If the child process is killed while an operation is running, CUDA
2055
+ tensors may leak in the current PyTorch implementation. TODO fix
2056
+
2057
+ WARNING: As this uses a separate CUDA context for the subprocess, performance
2058
+ may be slower than using NCCL directly. Separate CUDA contexts cannot run
2059
+ at the same time, so network and compute kernels will not overlap execution
2060
+ and will instead time-share, which may reduce GPU utilization.
2061
+ """
2062
+
2063
+ @classmethod
2064
+ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
2065
+ from torch.distributed import ProcessGroupNCCL as BaseProcessGroupNCCL
2066
+
2067
+ pg = BaseProcessGroup(store, rank, world_size)
2068
+ pg._set_default_backend(ProcessGroup.BackendType.NCCL)
2069
+ # pyre-fixme[16]: no attribute ProcessGroupNCCL
2070
+ backend_class = BaseProcessGroupNCCL(store, rank, world_size)
2071
+ backend_class._set_sequence_number_for_group()
2072
+ pg._register_backend(
2073
+ torch.device("cuda"), ProcessGroup.BackendType.NCCL, backend_class
2074
+ )
2075
+ return pg
2076
+
2077
+ def getBackendName(self) -> str:
2078
+ return "torchft-baby-nccl"
2079
+
2080
+
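A minimal usage sketch for the stream-only synchronization described above; the configure() call and its arguments are assumed from the surrounding module and may differ, and the key point is the explicit CPU-side sync:

    # Illustrative sketch, not part of this file.
    pg = ProcessGroupBabyNCCL()
    pg.configure(store_addr, rank, world_size)  # signature assumed
    t = torch.ones(8, device="cuda")
    work = pg.allreduce([t], AllreduceOptions())
    work.wait()               # only makes the current CUDA stream wait
    torch.cuda.synchronize()  # required before reading `t` on the CPU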
2081
+ class ProcessGroupBabyXCCL(ProcessGroupBaby):
2082
+ """
2083
+ This is a ProcessGroup that runs XCCL in a subprocess for Intel XPU devices.
2084
+
2085
+ For the XCCL backend, extra memory will be used by the subprocess's XPU
2086
+ context compared to running XCCL in the main process. This is typically
2087
+ dependent on the XPU memory architecture.
2088
+
2089
+ The returned Work objects only synchronize on the XPU stream and not on the
2090
+ CPU side. This works by passing XPU Events between the processes. To
2092
+ synchronize with the CPU, call torch.xpu.synchronize() after wait().
2092
+
2093
+ WARNING: If the child process is killed while an operation is running, XPU
2094
+ tensors may leak in the current PyTorch implementation. TODO fix
2095
+
2096
+ WARNING: As this uses a separate XPU context for the subprocess, performance
2097
+ may be slower than using XCCL directly. Separate XPU contexts cannot run
2098
+ at the same time, so network and compute kernels will not overlap execution
2099
+ and will instead time-share, which may reduce XPU utilization.
2100
+ """
2101
+
2102
+ @classmethod
2103
+ def _create_pg(cls, store: Store, rank: int, world_size: int) -> BaseProcessGroup:
2104
+ # Check if XPU and XCCL are available
2105
+ from torch.distributed import ProcessGroupXCCL as BaseProcessGroupXCCL
2106
+
2107
+ pg = BaseProcessGroup(store, rank, world_size)
2108
+ pg._set_default_backend(ProcessGroup.BackendType.XCCL)
2109
+ # pyre-fixme[16]: no attribute ProcessGroupXCCL
2110
+ backend_class = BaseProcessGroupXCCL(store, rank, world_size)
2111
+ backend_class._set_sequence_number_for_group()
2112
+ pg._register_backend(
2113
+ torch.device("xpu"), ProcessGroup.BackendType.XCCL, backend_class
2114
+ )
2115
+ return pg
2116
+
2117
+ def getBackendName(self) -> str:
2118
+ return "torchft-baby-xccl"