torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager.py ADDED
@@ -0,0 +1,1358 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ """
8
+ Manager
9
+ =========
10
+
11
+ This module implements the Manager that manages the full fault tolerant training
12
+ loop.
13
+
14
+ The Manager is responsible for managing the
15
+ full training loop, communicating with the Lighthouse server to figure out
16
+ quorum, reconfiguring the ProcessGroups and restoring checkpoint state when
17
+ recovering.
18
+
19
+ This uses wrapper classes to wrap the standard PyTorch Optimizer and Module
20
+ classes to provide fault tolerance. These wrappers are intended to add fault
21
+ tolerance with minimal changes to the user's modeling code and training loop.
22
+
23
+ This is designed to work with the standard PyTorch DistributedDataParallel module
24
+ and Hybrid FSDP.
25
+
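+ A minimal usage sketch is shown below. The wrapper signatures from
+ ``torchft.ddp`` and ``torchft.optim`` are assumed here for illustration only;
+ see those modules for the authoritative APIs. This also assumes the usual
+ torchelastic environment variables (``MASTER_ADDR``, ``MASTER_PORT``, ``RANK``,
+ ``WORLD_SIZE``) and ``TORCHFT_LIGHTHOUSE`` are set::
+
+     import torch
+     from torch import nn, optim
+
+     from torchft.ddp import DistributedDataParallel
+     from torchft.manager import Manager
+     from torchft.optim import OptimizerWrapper
+     from torchft.process_group import ProcessGroupGloo
+
+     m = nn.Linear(2, 3)
+
+     manager = Manager(
+         pg=ProcessGroupGloo(),
+         min_replica_size=1,
+         load_state_dict=m.load_state_dict,
+         state_dict=m.state_dict,
+     )
+
+     m = DistributedDataParallel(manager, m)
+     optimizer = OptimizerWrapper(manager, optim.AdamW(m.parameters()))
+
+     for _ in range(1000):
+         batch = torch.rand(8, 2)
+         # the optimizer wrapper starts the (async) quorum on zero_grad() and
+         # only steps the inner optimizer if manager.should_commit() passes
+         optimizer.zero_grad()
+         m(batch).sum().backward()
+         optimizer.step()
+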
26
+ """
27
+
28
+ import concurrent.futures
29
+ import logging
30
+ import os
31
+ import socket
32
+ import traceback
33
+ import uuid
34
+ import weakref
35
+ from concurrent.futures import ThreadPoolExecutor
36
+ from contextlib import nullcontext
37
+ from datetime import timedelta
38
+ from enum import Enum
39
+ from typing import (
40
+ Any,
41
+ Callable,
42
+ cast,
43
+ Dict,
44
+ List,
45
+ Optional,
46
+ TYPE_CHECKING,
47
+ TypeAlias,
48
+ TypeVar,
49
+ Union,
50
+ )
51
+
52
+ import torch
53
+ import torch.distributed as dist
54
+ from torch.distributed import ReduceOp, TCPStore
55
+ from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
56
+
57
+ from torchft._torchft import ManagerClient, ManagerServer
58
+ from torchft.checkpointing import CheckpointTransport, HTTPTransport
59
+ from torchft.checkpointing._rwlock import RWLock
60
+ from torchft.futures import future_timeout
61
+ from torchft.utils import get_stream_context, synchronize
62
+ from torchft.work import _DummyWork
63
+
64
+ if TYPE_CHECKING:
65
+ from torchft.process_group import ProcessGroup
66
+
67
+ IS_TRITON_AVAILABLE = True
68
+ try:
69
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton`
70
+ import triton
71
+
72
+ from torchft.collectives import allreduce_quantized
73
+ except ImportError:
74
+ IS_TRITON_AVAILABLE = False
75
+
76
+ MANAGER_ADDR_KEY: str = "manager_addr"
77
+ MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
78
+ REPLICA_ID_KEY: str = "replica_id"
79
+
80
+ # Environment variables for various timeouts. These can also be passed
81
+ # in through the manager but the environment variables take precedence.
82
+ TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC"
83
+ QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
84
+ CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"
85
+
86
+ # Environment variable for the number of retries to use for the quorum.
87
+ # We need to retry the quorum in case the lighthouse fails. Otherwise, if we
88
+ # crashed whenever a call to quorum failed, all replicas would crash.
89
+ QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"
90
+
91
+ TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"
92
+
93
+ T = TypeVar("T")
94
+
95
+
96
+ def get_timeout(
97
+ timeout_sec_env: str | None, default_timeout_sec: timedelta
98
+ ) -> timedelta:
99
+ """
100
+ Get the timeout from the environment variable or the default value.
101
+
102
+ Args:
103
+ timeout_sec_env: The value of the timeout environment variable, if set
104
+ default_timeout_sec: The default timeout
105
+ Returns:
106
+ The timeout to use. Environment variable takes precedence.
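+
+     For example, with ``TORCHFT_TIMEOUT_SEC=120`` set in the environment::
+
+         get_timeout(os.environ.get(TIMEOUT_SEC_ENV), timedelta(seconds=60))
+         # -> timedelta(seconds=120)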
107
+ """
108
+ if timeout_sec_env is not None:
109
+ return timedelta(seconds=int(timeout_sec_env))
110
+
111
+ return default_timeout_sec
112
+
113
+
114
+ def extract_trailing_digits(s: str) -> int:
115
+ """
116
+ Extracts the trailing digits from the end of the string s.
117
+ Returns 0 if no trailing digits are found.
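+
+     For example::
+
+         extract_trailing_digits("replica_group_7")  # -> 7
+         extract_trailing_digits("replica")          # -> 0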
118
+ """
119
+ i = len(s) - 1
120
+ while i >= 0 and s[i].isdigit():
121
+ i -= 1
122
+ return int(s[i + 1 :]) if i < len(s) - 1 else 0
123
+
124
+
125
+ class WorldSizeMode(Enum):
126
+ """
127
+ This controls the numerics for the job when doing allreduces across replicas
128
+ when the world size is larger than ``min_replica_size``. The world size will
129
+ never be smaller than ``min_replica_size``.
130
+
131
+ DYNAMIC:
132
+ The world size will dynamically increase to use all available
133
+ replicas and normalize the gradient by the world size.
134
+ FIXED_WITH_SPARES:
135
+ The number of active replicas is ``min_replica_size`` and any spares
136
+ will contribute zero gradients.
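+
+     For example, with ``min_replica_size=2`` and three healthy replicas:
+     DYNAMIC lets all three replicas participate and normalizes gradients by
+     three, while FIXED_WITH_SPARES keeps two participating replicas, has the
+     spare contribute zeroed gradients, and normalizes by two.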
137
+ """
138
+
139
+ DYNAMIC = 0
140
+ FIXED_WITH_SPARES = 1
141
+
142
+
143
+ class ExceptionWithTraceback(Exception):
144
+ def __init__(self, e: Exception) -> None:
145
+ self.original_exception = e
146
+ self.stack_trace: str = traceback.format_exc()
147
+ super().__init__(f"{e}\n{self.stack_trace}")
148
+
149
+
150
+ class Manager:
151
+ """
152
+ Manager manages the full fault tolerant training loop.
153
+
154
+ This requires that the TCPStore specified by the store_addr and
155
+ store_port or the MASTER_ADDR and MASTER_PORT environment variables be
156
+ started prior to creating this manager. If using a modern version of
157
+ torchelastic this will already be the case. Otherwise, it should be started
158
+ via torch.distributed.init_process_group prior to creating this manager.
159
+
160
+ NOTE: when saving periodic checkpoints you must save and restore the
161
+ Manager's state_dict as well to avoid synchronization issues.
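+
+     A rough sketch of a periodic checkpoint that includes the manager state
+     (``model`` and ``optimizer`` are the user's own objects)::
+
+         checkpoint = {
+             "model": model.state_dict(),
+             "optim": optimizer.state_dict(),
+             "manager": manager.state_dict(),
+         }
+         torch.save(checkpoint, path)
+
+         # on restore
+         state = torch.load(path)
+         model.load_state_dict(state["model"])
+         optimizer.load_state_dict(state["optim"])
+         manager.load_state_dict(state["manager"])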
162
+ """
163
+
164
+ def __init__(
165
+ self,
166
+ pg: "ProcessGroup",
167
+ load_state_dict: Optional[Callable[[T], None]],
168
+ state_dict: Optional[Callable[[], T]],
169
+ min_replica_size: int,
170
+ use_async_quorum: bool = True,
171
+ timeout: timedelta = timedelta(seconds=60),
172
+ quorum_timeout: timedelta = timedelta(seconds=60),
173
+ connect_timeout: timedelta = timedelta(seconds=60),
174
+ rank: Optional[int] = None,
175
+ world_size: Optional[int] = None,
176
+ world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
177
+ store_addr: Optional[str] = None,
178
+ store_port: Optional[int] = None,
179
+ lighthouse_addr: Optional[str] = None,
180
+ replica_id: Optional[str] = None,
181
+ port: Optional[int] = None,
182
+ hostname: str = socket.gethostname(),
183
+ heartbeat_interval: timedelta = timedelta(milliseconds=100),
184
+ checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
185
+ init_sync: bool = True,
186
+ max_retries: Optional[int] = None,
187
+ quorum_retries: int = 0,
188
+ ) -> None:
189
+ """
190
+ Args:
191
+ load_state_dict: function to load the state dict when recovering
192
+ state_dict: function to save the state dict when recovering
193
+ min_replica_size: minimum number of replicas on each step
194
+ port: if rank==0, the port to run the manager server on.
195
+ Port assignment priority:
196
+ 1. this argument
197
+ 2. TORCHFT_MANAGER_PORT env var
198
+ 3. arbitrary port assigned via 0
199
+ use_async_quorum: whether to run the quorum asynchronously during the forward pass
200
+ timeout: the default timeout for all operations
201
+ Included:
202
+ * collectives such as allreduce
203
+ * should_commit rpc
204
+ * checkpoint_address rpc
205
+ * checkpoint HTTP operations
206
+ * wrap_future
207
+ quorum_timeout: the default timeout to wait for the quorum to complete.
208
+ This generally should be longer than the training step time /
209
+ the interval between quorum checks to avoid any split brain
210
+ issues.
211
+
212
+ For LocalSGD/DiLoCo this may need to be set to ~1h or longer
213
+ depending on how frequently the syncs occur.
214
+ connect_timeout: the timeout used for establishing rpc connections
215
+ to ManagerServer and Lighthouse
216
+ rank: the replica group local rank, referred to as group_rank in manager.py for clarity
217
+ world_size: the replica group local world size, referred to as group_world_size in manager.py for clarity
218
+ store_addr: TCPStore address for this replica group
219
+ store_port: TCPStore port for this replica group
220
+ lighthouse_addr: if rank==0, the address of the lighthouse server
221
+ replica_id: if rank==0, the replica_id for this group
222
+ hostname: if rank==0, the hostname to advertise to the lighthouse server
223
+ checkpoint_transport: the checkpoint transport to use for
224
+ transferring checkpoints to recovering replicas, defaults to HTTPTransport
225
+ init_sync: whether to synchronize the model weights on step 0. If
226
+ all of the model weights are initialized identically via
227
+ ``torch.manual_seed`` you should set this to False.
228
+ max_retries: the maximum number of consecutive should_commit failures to allow
229
+ before raising an exception. If None, will retry indefinitely.
230
+ quorum_retries: the number of times to retry the quorum before crashing
231
+ """
232
+ self.quorum_logger: logging.Logger = logging.getLogger("torchft_quorums")
233
+ self.commits_logger: logging.Logger = logging.getLogger("torchft_commits")
234
+ self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")
235
+
236
+ self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
237
+ self._user_state_dicts: Dict[str, Callable[[], object]] = {}
238
+
239
+ self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
240
+ TORCH_FR_DUMP_TEMP_FILE_ENV
241
+ )
242
+ self._replica_id = replica_id
243
+
244
+ # Protects state dict
245
+ self._state_dict_lock = RWLock(timeout=timeout.total_seconds())
246
+
247
+ if load_state_dict and state_dict:
248
+ self.register_state_dict_fn("default", load_state_dict, state_dict)
249
+
250
+ self._pending_state_dict: Optional[Dict[str, object]] = None
251
+ self._use_async_quorum = use_async_quorum
252
+
253
+ self._timeout: timedelta = get_timeout(
254
+ os.environ.get(TIMEOUT_SEC_ENV, None), timeout
255
+ )
256
+ self._quorum_timeout: timedelta = get_timeout(
257
+ os.environ.get(QUORUM_TIMEOUT_SEC_ENV, None), quorum_timeout
258
+ )
259
+ self._connect_timeout: timedelta = get_timeout(
260
+ os.environ.get(CONNECT_TIMEOUT_SEC_ENV, None), connect_timeout
261
+ )
262
+
263
+ self._replica_world_size_mode = world_size_mode
264
+ self._init_sync = init_sync
265
+ self._max_retries = max_retries
266
+ self._commit_failures = 0
267
+
268
+ self._quorum_retries: int = int(
269
+ os.environ.get(QUORUM_RETRIES_ENV, str(quorum_retries))
270
+ )
271
+
272
+ store_addr = store_addr or os.environ["MASTER_ADDR"]
273
+ store_port = store_port or int(os.environ["MASTER_PORT"])
274
+ self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
275
+ group_rank = self._group_rank
276
+ self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
277
+ self._min_replica_size = min_replica_size
278
+
279
+ if checkpoint_transport is None:
280
+ checkpoint_transport = HTTPTransport[Dict[str, T]](
281
+ timeout=timeout,
282
+ num_chunks=0,
283
+ )
284
+
285
+ self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = (
286
+ checkpoint_transport
287
+ )
288
+ self._executor = ThreadPoolExecutor(
289
+ max_workers=1, thread_name_prefix="async_quorum"
290
+ )
291
+ self._quorum_future: Optional[concurrent.futures.Future] = None
292
+
293
+ self._store = TCPStore(
294
+ host_name=store_addr,
295
+ port=store_port,
296
+ is_master=False,
297
+ wait_for_workers=False,
298
+ )
299
+ self._pg = pg
300
+ self._manager: Optional[ManagerServer] = None
301
+
302
+ self._recovery_stream: Optional["torch.Stream"] = (
303
+ torch.Stream() if torch.accelerator.is_available() else None
304
+ )
305
+
306
+ # Used to synchronize recovery operation
307
+ self._recovery_event: Optional[torch.Event] = None
308
+
309
+ if self._group_rank == 0:
310
+ if port is None:
311
+ port = int(os.environ.get(MANAGER_PORT_ENV, 0))
312
+
313
+ bind = f"[::]:{port}"
314
+ lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]
315
+
316
+ # We need a unique identifier in the case that a worker restarts quickly and
317
+ # replaces the previous worker with the same ID.
318
+ new_uuid = str(uuid.uuid4())
319
+ if replica_id is None or replica_id == "":
320
+ replica_id = new_uuid
321
+ else:
322
+ replica_id = f"{replica_id}:{new_uuid}"
323
+ self._manager = ManagerServer(
324
+ replica_id=replica_id,
325
+ lighthouse_addr=lighthouse_addr,
326
+ hostname=hostname,
327
+ bind=bind,
328
+ store_addr=f"{store_addr}:{store_port}",
329
+ world_size=self._group_world_size,
330
+ heartbeat_interval=heartbeat_interval,
331
+ connect_timeout=connect_timeout,
332
+ quorum_retries=self._quorum_retries,
333
+ )
334
+
335
+ self._store.set(MANAGER_ADDR_KEY, self._manager.address())
336
+ self._store.set(REPLICA_ID_KEY, replica_id)
337
+
338
+ addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
339
+ self._client = ManagerClient(addr, connect_timeout=connect_timeout)
340
+
341
+ replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
342
+ self._logger = _ManagerLogger(
343
+ manager=self, replica_id=replica_id or "", group_rank=group_rank
344
+ )
345
+
346
+ self._step = 0
347
+ self._quorum_id = -1
348
+ self._errored: Optional[ExceptionWithTraceback] = None
349
+ self._healing = False
350
+ self._batches_committed = 0
351
+
352
+ # first step is 1
353
+ self._participating_replica_rank: Optional[int] = None
354
+ self._participating_replica_world_size: int = 0
355
+ self._is_state_dict_read_allowed = True
356
+
357
+ self._global_rank: int = (
358
+ self._group_rank
359
+ if self._replica_id is None
360
+ else (
361
+ extract_trailing_digits(self._replica_id) * self._group_world_size
362
+ + self._group_rank
363
+ )
364
+ )
365
+
366
+ self._update_fr_path()
367
+
368
+ def allow_state_dict_read(self) -> None:
369
+ if self._is_state_dict_read_allowed:
370
+ return
371
+
372
+ self._is_state_dict_read_allowed = True
373
+ self._state_dict_lock.w_release()
374
+
375
+ def disallow_state_dict_read(self) -> None:
376
+ if not self._is_state_dict_read_allowed:
377
+ return
378
+
379
+ self._is_state_dict_read_allowed = False
380
+ self._state_dict_lock.w_acquire()
381
+
382
+ def register_state_dict_fn(
383
+ self,
384
+ key: str,
385
+ load_state_dict: Callable[[T], None],
386
+ state_dict: Callable[[], T],
387
+ ) -> None:
388
+ # Can't register duplicate keys
389
+ assert key not in self._load_state_dict_fns
390
+ assert key not in self._user_state_dicts
391
+
392
+ self._load_state_dict_fns[key] = cast(Callable[[object], None], load_state_dict)
393
+ self._user_state_dicts[key] = state_dict
394
+
395
+ def set_state_dict_fns(
396
+ self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
397
+ ) -> None:
398
+ self._logger.warn(
399
+ "`set_state_dict_fns` is deprecated, please use `register_state_dict_fn` instead"
400
+ )
401
+ self.register_state_dict_fn("set_state_dict_fns", load_state_dict, state_dict)
402
+
403
+ def shutdown(self, wait: bool = True) -> None:
404
+ """
405
+ Shutdown the manager and checkpoint server.
406
+ """
407
+ self._checkpoint_transport.shutdown(wait=wait)
408
+ if self._manager is not None:
409
+ self._manager.shutdown()
410
+ self._executor.shutdown(wait=wait)
411
+
412
+ @torch.profiler.record_function("torchft::manager::allreduce")
413
+ def allreduce(
414
+ self,
415
+ tensor: torch.Tensor,
416
+ should_quantize: bool = False,
417
+ reduce_op: ReduceOp = ReduceOp.AVG,
418
+ ) -> Work:
419
+ """
420
+ Fault tolerant allreduce of the tensor, returning a Work handle that completes when
421
+ the tensor is ready.
422
+
423
+ This will automatically scale the tensor by 1 / world_size.
424
+
425
+ If an error occurs during the allreduce:
426
+
427
+ * The returned Work will complete without raising an error; the failure is instead reported to the manager and tracked asynchronously.
428
+ * After the first error, all subsequent calls will be noops and immediately return.
429
+ * The tensor must be zeroed before being used as it may be corrupted.
430
+
431
+ Args:
432
+ tensor: the tensor to allreduce
433
+ should_quantize: whether the tensor should be quantized before communication
434
+ Returns:
435
+ a Work handle that can be waited on; the tensor is updated in place with the allreduced result
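+
+         For example (a sketch; ``manager`` is this Manager instance)::
+
+             grad = torch.ones(8)
+             work = manager.allreduce(grad)
+             # ... overlap other computation here ...
+             work.wait()
+             # grad now holds the average across all participating replicas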
436
+ """
437
+ if self.errored():
438
+ return _DummyWork(tensor)
439
+
440
+ self.wait_quorum()
441
+ num_participants: int = self.num_participants()
442
+
443
+ if not self.is_participating():
444
+ tensor.zero_()
445
+
446
+ # special logic for average
447
+ pg_reduce_op = reduce_op
448
+ if reduce_op == ReduceOp.AVG:
449
+ if not torch.is_floating_point(tensor):
450
+ raise ValueError(
451
+ "average reduce op is only supported for floating point tensors"
452
+ )
453
+ pg_reduce_op = ReduceOp.SUM
454
+
455
+ # TODO: increase timeout when waiting when healing
456
+ try:
457
+ # Run the allreduce async and save the work object so we can wait on
458
+ # it later.
459
+ if should_quantize and IS_TRITON_AVAILABLE:
460
+ work = allreduce_quantized(
461
+ [tensor],
462
+ pg_reduce_op,
463
+ self._pg,
464
+ # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
465
+ torch.accelerator.current_stream(),
466
+ )
467
+ else:
468
+ opts = AllreduceOptions()
469
+ opts.reduceOp = pg_reduce_op
470
+ work = self._pg.allreduce([tensor], opts)
471
+
472
+ # schedule grad normalization as a continuation
473
+ # on the Future
474
+ @torch.profiler.record_function("torchft::manager::allreduce::callback")
475
+ def callback(
476
+ fut: torch.futures.Future[torch.Tensor],
477
+ ) -> torch.Tensor:
478
+ nonlocal tensor
479
+ if reduce_op == ReduceOp.AVG:
480
+ tensor /= num_participants
481
+ return tensor
482
+
483
+ managed_work = _ManagedWork(self, work, tensor)
484
+ fut = managed_work.get_future()
485
+ fut = cast(torch.futures.Future[torch.Tensor], fut)
486
+ fut = fut.then(callback)
487
+ return managed_work
488
+
489
+ except Exception as e:
490
+ self._logger.exception(
491
+ f"got exception in all reduce -- skipping remaining: {e}"
492
+ )
493
+ self.report_error(e)
494
+
495
+ return _DummyWork(tensor)
496
+
497
+ def report_error(self, e: Exception) -> None:
498
+ """
499
+ Report an error to the manager.
500
+
501
+ This will cause the manager to skip the current step and will be
502
+ reconfigured on the next step.
503
+
504
+ This should be called when an error occurs that leads to a corrupted
505
+ gradient that needs to be discarded.
506
+ """
507
+ self._errored = ExceptionWithTraceback(e)
508
+
509
+ def errored(self) -> Optional[ExceptionWithTraceback]:
510
+ """
511
+ Get whether an error has occurred.
512
+
513
+ Returns:
514
+ The error or None if no error has occurred.
515
+ """
516
+ return self._errored
517
+
518
+ def wrap_future(
519
+ self,
520
+ fut: torch.futures.Future[T],
521
+ default: T,
522
+ timeout: Optional[timedelta] = None,
523
+ ) -> torch.futures.Future[T]:
524
+ """
525
+ Wrap a Future, swallowing any errors that occur and reporting them to the manager.
526
+
527
+ If an error occurs, the Future will be completed with the default value.
528
+
529
+ Args:
530
+ fut: the Future to wrap
531
+ default: the default value to complete the Future with if an error occurs
532
+ timeout: the timeout for the Future, if None, the manager's timeout will be used
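+
+         For example (a sketch; ``pg`` and ``t`` are an illustrative process
+         group and tensor)::
+
+             opts = AllreduceOptions()
+             opts.reduceOp = ReduceOp.SUM
+             fut = pg.allreduce([t], opts).get_future()
+             # on error or timeout the future resolves to ``t`` and the error
+             # is reported to the manager instead of being raised
+             fut = manager.wrap_future(fut, t)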
533
+ """
534
+
535
+ fut = future_timeout(fut, timeout or self._timeout)
536
+
537
+ stream: Optional[torch.Stream] = (
538
+ torch.accelerator.current_stream()
539
+ if torch.accelerator.is_available()
540
+ else None
541
+ )
542
+
543
+ # schedule error handling as a continuation on the Future
544
+ def callback(
545
+ fut: torch.futures.Future[T],
546
+ ) -> T:
547
+ nonlocal default, stream
548
+
549
+ with get_stream_context(stream):
550
+ try:
551
+ return fut.value()
552
+ except Exception as e:
553
+ self._logger.exception(
554
+ f"got exception in future -- skipping remaining: {e}"
555
+ )
556
+ self.report_error(e)
557
+ return default
558
+
559
+ fut = fut.then(callback)
560
+ return fut
561
+
562
+ def start_quorum(
563
+ self,
564
+ allow_heal: bool = True,
565
+ shrink_only: bool = False,
566
+ timeout: Optional[timedelta] = None,
567
+ ) -> None:
568
+ """
569
+ .. note::
570
+ We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
571
+
572
+ Computes a new quorum (potentially asynchronously) and readies the
573
+ manager for a new step.
574
+
575
+ It's best practice to call this before the forwards pass of each step for
576
+ performance as computing quorum may take some time.
577
+
578
+ Args:
579
+ allow_heal: (experimental) whether to allow healing at the beginning of the step
580
+ If allow_heal is set, the manager will attempt to heal either
581
+ synchronously before returning or asynchronously prior to any network
582
+ calls. All replicas must pass the same value to allow_heal.
583
+ timeout: the timeout for the quorum to be ready; if None, the manager's quorum_timeout will be used
584
+ recovery operations will use the manager timeout
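+
+         A sketch of manual use without the optimizer wrapper (``model``,
+         ``batch`` and ``optimizer`` are the user's own objects)::
+
+             manager.start_quorum()  # ideally before the forward pass
+             loss = model(batch).sum()
+             loss.backward()
+             for p in model.parameters():
+                 if p.grad is not None:
+                     manager.allreduce(p.grad).wait()
+             if manager.should_commit():
+                 optimizer.step()
+             optimizer.zero_grad()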
585
+ """
586
+
587
+ # wait for previous quorum to complete
588
+ if self._quorum_future is not None:
589
+ self._quorum_future.result()
590
+
591
+ self._errored = None
592
+ self._healing = False
593
+
594
+ # TODO: we should really be wrapping this whole section in a try-except
595
+ # block to allow gracefully recovering from issues in PG setup and quorum.
596
+
597
+ self._quorum_future = self._executor.submit(
598
+ self._async_quorum,
599
+ allow_heal=allow_heal,
600
+ shrink_only=shrink_only,
601
+ quorum_timeout=timeout or self._quorum_timeout,
602
+ curr_device=(
603
+ torch.accelerator.current_device_index()
604
+ if torch.accelerator.is_available()
605
+ else -1
606
+ ),
607
+ )
608
+ if not self._use_async_quorum:
609
+ self.wait_quorum()
610
+
611
+ if self._healing:
612
+ # eagerly apply pending state_dict so we can run the forwards pass
613
+ self._apply_pending_state_dict()
614
+
615
+ # we are forcing healing at the beginning so we're in a good state
616
+ # and don't need to zero_grad
617
+ self._healing = False
618
+
619
+ @torch.profiler.record_function("torchft::manager::wait_quorum")
620
+ def wait_quorum(self) -> None:
621
+ """
622
+ Wait for the quorum to complete.
623
+
624
+ ProcessGroup will be in a healthy state after this returns.
625
+ """
626
+ assert (
627
+ self._quorum_future is not None
628
+ ), "must call start_quorum before wait_quorum"
629
+ self._quorum_future.result()
630
+
631
+ @torch.profiler.record_function("torchft::manager::_async_quorum")
632
+ def _async_quorum(
633
+ self,
634
+ allow_heal: bool,
635
+ shrink_only: bool,
636
+ quorum_timeout: timedelta,
637
+ curr_device: int,
638
+ ) -> None:
639
+ torch.multiprocessing._set_thread_name("torchft_quorum")
640
+
641
+ if curr_device >= 0 and torch.accelerator.is_available():
642
+ torch.accelerator.set_device_index(curr_device)
643
+
644
+ quorum = None
645
+ with torch.profiler.record_function("torchft::manager::_client::_quorum"):
646
+ quorum = self._client._quorum(
647
+ group_rank=self._group_rank,
648
+ step=self._step,
649
+ checkpoint_metadata=self._checkpoint_transport.metadata(),
650
+ shrink_only=shrink_only,
651
+ timeout=quorum_timeout,
652
+ init_sync=self._init_sync,
653
+ commit_failures=self._commit_failures,
654
+ )
655
+
656
+ quorum_id = quorum.quorum_id
657
+ replica_rank = quorum.replica_rank
658
+ replica_world_size = quorum.replica_world_size
659
+ recover_src_manager_address = quorum.recover_src_manager_address
660
+ store_address = quorum.store_address
661
+ max_step = quorum.max_step
662
+ max_replica_rank = quorum.max_replica_rank
663
+ max_replica_world_size = quorum.max_world_size
664
+ heal = quorum.heal
665
+ replica_ids = quorum.replica_ids
666
+
667
+ ranks_in_quorum = [
668
+ extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
669
+ + self._group_rank
670
+ for replica_id in replica_ids
671
+ ]
672
+
673
+ # When using async quorum we need to take the recovered workers.
674
+ # When not using async quorum we need to take the max world size as all
675
+ # workers will be healthy.
676
+ self._participating_replica_rank, self._participating_replica_world_size = (
677
+ (max_replica_rank, max_replica_world_size)
678
+ if self._use_async_quorum or not allow_heal
679
+ else (replica_rank, replica_world_size)
680
+ )
681
+
682
+ # For fixed with spares we need to ensure that we don't have more
683
+ # participating replicas than the min replica size.
684
+ if self._replica_world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
685
+ self._participating_replica_world_size = min(
686
+ self._participating_replica_world_size, self._min_replica_size
687
+ )
688
+ if (
689
+ self._participating_replica_rank is not None
690
+ and self._participating_replica_rank >= self._min_replica_size
691
+ ):
692
+ self._participating_replica_rank = None
693
+
694
+ if quorum_id != self._quorum_id:
695
+ self.quorum_logger.info(
696
+ "",
697
+ extra={
698
+ "job_id": os.environ.get("JOB_ID", "unknown"),
699
+ "replica_id": self._replica_id,
700
+ "rank": self._group_rank,
701
+ "quorum_id": quorum_id,
702
+ "step": max_step,
703
+ },
704
+ )
705
+ store_prefixed_addr = (
706
+ f"{store_address}/torchft/{quorum_id}/{self._group_rank}"
707
+ )
708
+
709
+ self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
710
+ # We use the replica rank and world as we want all replicas in the PG.
711
+ try:
712
+ self._quorum_id = quorum_id
713
+ with torch.profiler.record_function("torchft::manager::_pg::configure"):
714
+ # Reset GPU state for Flight Recorder
715
+ if torch.accelerator.is_available():
716
+ torch.accelerator.synchronize()
717
+
718
+ self._pg.configure(
719
+ store_prefixed_addr,
720
+ self._replica_id if self._replica_id is not None else "0",
721
+ replica_rank,
722
+ replica_world_size,
723
+ quorum_id,
724
+ self._group_rank,
725
+ self._group_world_size,
726
+ ranks_in_quorum,
727
+ )
728
+
729
+ # We need to reset the trace after reconfiguring the PG because that
730
+ # calls abort which may trigger a dump
731
+ self._logger.info(
732
+ f"resetting fr recording for quorum id {self._quorum_id}"
733
+ )
734
+ self._update_fr_path()
735
+ torch._C._distributed_c10d._reset_fr_recording_nccl() # pyre-ignore
736
+ except Exception as e:
737
+ self._logger.exception(f"got exception in pg configure: {e}")
738
+ self.report_error(e)
739
+ return
740
+
741
+ if allow_heal:
742
+ # run recovery on the recovery stream if available
743
+ recovery_stream = self._recovery_stream
744
+ with get_stream_context(recovery_stream):
745
+ try:
746
+ if quorum.recover_dst_replica_ranks:
747
+ self._logger.info(
748
+ f"peers need recovery from us {quorum.recover_dst_replica_ranks}"
749
+ )
750
+ with torch.profiler.record_function(
751
+ "torchft::manager::_checkpoint_transport::send_checkpoint"
752
+ ):
753
+ self._checkpoint_transport.send_checkpoint(
754
+ dst_ranks=quorum.recover_dst_replica_ranks,
755
+ step=max_step,
756
+ state_dict=self._manager_state_dict(),
757
+ timeout=self._timeout,
758
+ )
759
+
760
+ # See manager.rs for healing conditions
761
+ if heal:
762
+ self._healing = True
763
+ self._logger.info(
764
+ f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
765
+ )
766
+ primary_client = ManagerClient(
767
+ recover_src_manager_address,
768
+ connect_timeout=self._connect_timeout,
769
+ )
770
+ checkpoint_metadata = primary_client._checkpoint_metadata(
771
+ self._group_rank, timeout=self._timeout
772
+ )
773
+ recover_src_replica_rank = quorum.recover_src_replica_rank
774
+ assert (
775
+ recover_src_replica_rank is not None
776
+ ), "must have a recover rank when healing"
777
+
778
+ self._logger.info(
779
+ f"fetching checkpoint from {recover_src_replica_rank=} with {checkpoint_metadata=}"
780
+ )
781
+
782
+ # we apply the user state dict only when safe from the main thread
783
+ # save it for now
784
+ with torch.profiler.record_function(
785
+ "torchft::manager::_checkpoint_transport::recv_checkpoint"
786
+ ):
787
+ self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
788
+ src_rank=recover_src_replica_rank,
789
+ metadata=checkpoint_metadata, # Depending on group rank
790
+ step=max_step,
791
+ timeout=self._timeout,
792
+ )
793
+
794
+ # pyre-fixme[6]: got object
795
+ self.load_state_dict(self._pending_state_dict["torchft"])
796
+
797
+ # This isn't strictly needed as loading the state_dict above should
798
+ # restore the correct step but it makes writing tests simpler.
799
+ self._step = max_step
800
+ except Exception as e:
801
+ self._logger.exception(f"got exception in recovery: {e}")
802
+ self.report_error(e)
803
+
804
+ self._recovery_event = (
805
+ torch.accelerator.current_stream().record_event()
806
+ if recovery_stream is not None
807
+ else None
808
+ )
809
+
810
+ def _update_fr_path(self) -> None:
811
+ """
812
+ Update the path that flight recorder will dump the traces to.
813
+ The format is
814
+ <TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
815
+ """
816
+ if self._original_fr_dump_temp_file is not None:
817
+ folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
818
+ os.makedirs(folder, exist_ok=True)
819
+ os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
820
+
821
+ def _apply_pending_state_dict(self) -> None:
822
+ assert self._healing, "must be in healing state"
823
+
824
+ # synchronize on future
825
+ assert self._quorum_future is not None, "must call step before should_commit"
826
+ self._quorum_future.result()
827
+
828
+ pending_state_dict = self._pending_state_dict
829
+
830
+ if pending_state_dict is None:
831
+ assert self.errored(), "checkpoint was not staged and no error occurred"
832
+ else:
833
+ self._logger.info("applying pending state dict")
834
+
835
+ assert (
836
+ len(self._load_state_dict_fns) > 0
837
+ ), "user load_state_dict is not initialized."
838
+
839
+ pending_user_state_dict = cast(
840
+ Dict[str, object], pending_state_dict["user"]
841
+ )
842
+
843
+ for key in self._load_state_dict_fns.keys():
844
+ load_state_dict_fn = self._load_state_dict_fns[key]
845
+ load_state_dict_fn(pending_user_state_dict[key])
846
+
847
+ self._pending_state_dict = None
848
+ self._logger.info("Loaded state dict.")
849
+
850
+ @torch.profiler.record_function("torchft::manager::should_commit")
851
+ def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
852
+ """
853
+ .. note::
854
+ We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
855
+
856
+ Must be called after the backwards pass completes but before stepping the optimizer.
857
+
858
+ The optimizer must only be stepped if this returns True.
859
+
860
+ This must be called on all workers within a replica group. This uses a
861
+ collective to ensure all workers within a replica return the same value.
862
+ If an error occurs on any worker, all workers will return False.
863
+ Different replica groups may return different values.
864
+
865
+ This should only be called once per step.
866
+
867
+ If max_retries is set and should_commit fails that many times consecutively,
868
+ this method will raise a RuntimeError to prevent indefinite failure loops.
869
+
870
+ Returns:
871
+ True if the optimizer should be stepped, False otherwise
872
+ Raises:
873
+ RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
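+
+         For example (a sketch; ``optimizer`` is the user's own optimizer)::
+
+             if manager.should_commit():
+                 optimizer.step()
+             # gradients may be corrupted when this returns False, so always
+             # zero them before the next step
+             optimizer.zero_grad()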
874
+ """
875
+ # make sure recovery is complete before committing
876
+ with torch.profiler.record_function(
877
+ "torchft::manager::should_commmit::recovery_stream::synchronize"
878
+ ):
879
+ if self._recovery_event is not None:
880
+ self._recovery_event.synchronize()
881
+ self._recovery_event = None
882
+
883
+ with torch.profiler.record_function(
884
+ "torchft::manager::should_commit::current_stream::synchronize"
885
+ ):
886
+ if torch.accelerator.is_available():
887
+ synchronize()
888
+
889
+ if err := self._pg.errored():
890
+ self.report_error(err)
891
+
892
+ # apply state_dict if healing
893
+ if self._healing:
894
+ self._apply_pending_state_dict()
895
+
896
+ enough_replicas = self.num_participants() >= self._min_replica_size
897
+ local_should_commit = enough_replicas and self._errored is None
898
+ should_commit = self._client.should_commit(
899
+ self._group_rank,
900
+ self._step,
901
+ local_should_commit,
902
+ timeout=timeout or self._timeout,
903
+ )
904
+ self._logger.info(
905
+ f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
906
+ )
907
+
908
+ self.commits_logger.info(
909
+ "",
910
+ extra={
911
+ "job_id": os.environ.get("JOB_ID", "unknown"),
912
+ "replica_id": self._replica_id,
913
+ "rank": self._group_rank,
914
+ "quorum_id": self._quorum_id,
915
+ "step": self._step,
916
+ "commit_result": should_commit,
917
+ },
918
+ )
919
+
920
+ self._checkpoint_transport.disallow_checkpoint()
921
+
922
+ # decide whether we're in a healthy state to increase the step count
923
+ if should_commit:
924
+ self._step += 1
925
+ self._batches_committed += self.num_participants()
926
+ self._commit_failures = 0 # Reset failure counter on success
927
+ else:
928
+ self._commit_failures += 1
929
+ # Check if we've hit max retries
930
+ if (
931
+ self._max_retries is not None
932
+ and self._commit_failures > self._max_retries
933
+ ):
934
+ msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}"
935
+ self._logger.exception(msg)
936
+ raise RuntimeError(msg)
937
+
938
+ return should_commit
939
+
940
+ def load_state_dict(self, state_dict: Dict[str, int]) -> None:
941
+ """
942
+ Load the state dict from a previous checkpoint.
943
+
944
+ This will restore the step count and internal metadata.
945
+
946
+ Args:
947
+ state_dict: the state dict to load
948
+ """
949
+ self._step = state_dict["step"]
950
+ self._batches_committed = state_dict["batches_committed"]
951
+
952
+ def _manager_state_dict(self) -> Dict[str, object]:
953
+ with self._state_dict_lock.r_lock():
954
+ assert (
955
+ len(self._user_state_dicts) > 0
956
+ ), "user state_dict is not initialized."
957
+ return {
958
+ "user": {key: value() for key, value in self._user_state_dicts.items()},
959
+ "torchft": self.state_dict(),
960
+ }
961
+
962
+ def state_dict(self) -> Dict[str, int]:
963
+ """
964
+ Get the state dict for this manager.
965
+
966
+ This can be used to checkpoint the state of the manager to restore
967
+ from a previous checkpoint.
968
+
969
+ Returns:
970
+ the state dict for this manager
971
+ """
972
+ return {"step": self._step, "batches_committed": self._batches_committed}
973
+
974
+ def current_step(self) -> int:
975
+ """
976
+ Get the current step count.
977
+
978
+ This number is incremented on .step()
979
+
980
+ Returns:
981
+ the current step count
982
+ """
983
+ return self._step
984
+
985
+ def batches_committed(self) -> int:
986
+ """
987
+ Get the total number of batches committed across all steps and replicas.
988
+ For example, 5 replicas participating in 2 steps is 10 committed batches,
989
+ which may correspond to more than 10 examples depending on the batch size.
990
+
991
+ This number is incremented on .step()
992
+
993
+ Returns:
994
+ the total number of batches committed
995
+ """
996
+ return self._batches_committed
997
+
998
+ def participating_rank(self) -> Optional[int]:
999
+ """
1000
+ Get the replica group rank of the current quorum. This will be the same on all
1001
+ ranks within the replica group.
1002
+
1003
+ If this replica group is not participating in the current quorum, this will be None.
1004
+
1005
+ This will block on the async quorum if it is not yet ready.
1006
+
1007
+ Returns:
1008
+ the rank of the current quorum
1009
+ """
1010
+ if self._quorum_future is None:
1011
+ return None
1012
+
1013
+ self.wait_quorum()
1014
+
1015
+ return self._participating_replica_rank
1016
+
1017
+ def num_participants(self) -> int:
1018
+ """
1019
+ Get the number of participants in the current quorum.
1020
+
1021
+ This is the number of replicas participating in the current step.
1022
+
1023
+ This will block on the async quorum if it is not yet ready.
1024
+
1025
+ Returns:
1026
+ the number of participants in the current quorum
1027
+ """
1028
+ if self._quorum_future is None:
1029
+ return 0
1030
+
1031
+ self.wait_quorum()
1032
+
1033
+ assert self._participating_replica_world_size >= 0, "internal error"
1034
+ return self._participating_replica_world_size
1035
+
1036
+ def is_participating(self) -> bool:
1037
+ """
1038
+ Get whether this replica is participating in the current quorum.
1039
+
1040
+ Returns:
1041
+ whether this replica is participating in the current quorum
1042
+ """
1043
+ if self._participating_replica_rank is None:
1044
+ return False
1045
+ if self._healing:
1046
+ assert self._use_async_quorum
1047
+ return False
1048
+ return True
1049
+
1050
+
1051
+ class _ManagerLogger:
1052
+ def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None:
1053
+ self._logger: logging.Logger = logging.getLogger(__name__)
1054
+ self._replica_id = replica_id
1055
+ self._group_rank = group_rank
1056
+ self._manager = manager
1057
+
1058
+ def prefix(self) -> str:
1059
+ return f"[{self._replica_id}/{self._group_rank} - step {self._manager.current_step()}]"
1060
+
1061
+ def info(self, msg: str) -> None:
1062
+ self._logger.info(f"{self.prefix()} {msg}")
1063
+
1064
+ def warn(self, msg: str) -> None:
1065
+ self._logger.warn(f"{self.prefix()} {msg}")
1066
+
1067
+ def exception(self, msg: str) -> None:
1068
+ self._logger.exception(f"{self.prefix()} {msg}")
1069
+
1070
+
1071
+ T = TypeVar("T")
1072
+ S = TypeVar("S")
1073
+
1074
+
1075
+ class _SimpleFuture(torch.futures.Future[T]):
1076
+ """
1077
+ A simplified implementation of torch.futures.Future that wraps a value.
1078
+
1079
+ This class provides a minimal Future implementation that holds a pre-determined value.
1080
+ It's primarily used as a wrapper for values in the callback chain of `_ManagedFuture`.
1081
+ Most methods raise `NotImplementedError` as they're not intended to be called.
1082
+
1083
+ This class is designed to be used only in specific contexts where we don't
1084
+ want to call `value()` on the underlying `Future` as that would cause the CPU to block.
1085
+ """
1086
+
1087
+ def __init__(self, value: object) -> None:
1088
+ super().__init__()
1089
+ self._value = value
1090
+
1091
+ def value(self) -> object:
1092
+ return self._value
1093
+
1094
+ def then(
1095
+ self, callback: Callable[[torch.futures.Future[T]], S]
1096
+ ) -> torch.futures.Future[S]:
1097
+ raise NotImplementedError(
1098
+ "This future is only supposed to be used in callback chain to extract the value"
1099
+ )
1100
+
1101
+ def wait(self) -> object:
1102
+ raise NotImplementedError(
1103
+ "This future is only supposed to be used in callback chain to extract the value"
1104
+ )
1105
+
1106
+ def done(self) -> bool:
1107
+ raise NotImplementedError(
1108
+ "This future is only supposed to be used in callback chain to extract the value"
1109
+ )
1110
+
1111
+ def add_done_callback(
1112
+ self, callback: Callable[[torch.futures.Future[T]], None]
1113
+ ) -> None:
1114
+ raise NotImplementedError(
1115
+ "This future is only supposed to be used in callback chain to extract the value"
1116
+ )
1117
+
1118
+ def set_result(self, result: object) -> None:
1119
+ raise NotImplementedError(
1120
+ "This future is only supposed to be used in callback chain to extract the value"
1121
+ )
1122
+
1123
+ def set_exception(self, result: object) -> None:
1124
+ raise NotImplementedError(
1125
+ "This future is only supposed to be used in callback chain to extract the value"
1126
+ )
1127
+
1128
+
1129
+ class _ManagedFuture(torch.futures.Future[T]):
1130
+ """
1131
+ A specialized Future implementation that works alongside `_ManagedWork`.
1132
+
1133
+ This class extends torch.futures.Future to provide future chaining that is
1134
+ lazy - `then()` method simply stores the callback, which is only executed when
1135
+ `wait()` is called on `_ManagedFuture` or `_ManagedWork`
1136
+
1137
+ Callback chains are implemented as a linked list of `_ManagedFuture` objects through the
1138
+ `_next` attribute. When appending a callback to the chain, it also updates the tail of the
1139
+ linked list stored in `_ManagedWork`.
1140
+
1141
+ Delegates actual future operations to an internal torch.futures.Future.
1142
+
1143
+ Raises NotImplementedError for methods that should not be called.
1144
+ """
1145
+
1146
+ def __init__(self, managed_work: weakref.ReferenceType["_ManagedWork"]) -> None:
1147
+ super().__init__()
1148
+ # Store a weak reference to _ManagedWork to avoid reference cycles
1149
+ self._managed_work = managed_work
1150
+
1151
+ # The underlying torch.futures.Future that this class delegates to
1152
+ self._fut: Optional[torch.futures.Future[T]] = None
1153
+
1154
+ # The next future in the callback chain
1155
+ self._next: Optional[_ManagedFuture[object]] = None
1156
+
1157
+ # The callback to be executed when the future is completed - this callback
1158
+ # returns the next future in the chain
1159
+ self._callback: Optional[Callable[[torch.futures.Future[T]], object]] = None
1160
+
1161
+ def then(
1162
+ self,
1163
+ callback: Callable[[torch.futures.Future[T]], S],
1164
+ ) -> torch.futures.Future[S]:
1165
+ """
1166
+ Sets the callback to be executed when the future is completed.
1167
+
1168
+ Since the callback returns a future, this method also creates a new future
1169
+ in the chain and also updates the tail of the chain in `_ManagedWork`.
1170
+ """
1171
+ managed_work = self._managed_work()
1172
+ assert managed_work is not None, "got garbage collected"
1173
+
1174
+ self._callback = callback
1175
+ self._next = _ManagedFuture[object](self._managed_work)
1176
+ managed_work._managed_fut_tail = self._next
1177
+ return cast(torch.futures.Future[S], self._next)
1178
+
1179
+ def wait(self) -> object:
1180
+ assert self._fut
1181
+ return self._fut.wait()
1182
+
1183
+ def value(self) -> object:
1184
+ raise NotImplementedError(
1185
+ "This future is supposed to be used to create callback chain"
1186
+ )
1187
+
1188
+ def done(self) -> bool:
1189
+ raise NotImplementedError(
1190
+ "This future is supposed to be used to create callback chain"
1191
+ )
1192
+
1193
+ def add_done_callback(
1194
+ self, callback: Callable[[torch.futures.Future[T]], None]
1195
+ ) -> None:
1196
+ raise NotImplementedError(
1197
+ "This future is supposed to be used to create callback chain"
1198
+ )
1199
+
1200
+ def set_result(self, result: object) -> None:
1201
+ raise NotImplementedError(
1202
+ "This future is supposed to be used to create callback chain"
1203
+ )
1204
+
1205
+ def set_exception(self, result: object) -> None:
1206
+ raise NotImplementedError(
1207
+ "This future is supposed to be used to create callback chain"
1208
+ )
1209
+
1210
+
1211
+ class _ManagedWork(dist._Work):
1212
+ """
1213
+ A specialized `Work` implementation that works alongside `_ManagedFuture` to create
1214
+ callback chains lazily. The callback chain is created when `wait()`, `block_current_stream()`
1215
+ or `synchronize()` are called.
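+
+     A rough sketch of the intended flow (``pg_work`` is a ``Work`` returned by
+     a process group collective)::
+
+         work = _ManagedWork(manager, pg_work, tensor)
+         fut = work.get_future()
+         fut = fut.then(lambda f: f.value())  # callback is stored, not run yet
+         work.wait()  # materializes the chain and runs the callback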
1216
+ """
1217
+
1218
+ def __init__(
1219
+ self,
1220
+ manager: Manager,
1221
+ work: dist._Work,
1222
+ value: object,
1223
+ ) -> None:
1224
+ super().__init__()
1225
+ # Underlying `Work` returned from process group operations
1226
+ self._work = work
1227
+
1228
+ # Used to report errors to the manager through `wrap_future()`
1229
+ self._manager = manager
1230
+
1231
+ # The value returned by the final future in the callback chain
1232
+ self._value = value
1233
+
1234
+ # The head of the callback chain
1235
+ self._managed_fut_head = _ManagedFuture[object](weakref.ref(self))
1236
+
1237
+ # The tail of the callback chain
1238
+ self._managed_fut_tail: _ManagedFuture[object] = self._managed_fut_head
1239
+
1240
+ # The stream used to create the `Work` - we ensure all operations in the future
1241
+ # callback chain are executed on this stream
1242
+ self._stream: Optional[torch.Stream] = (
1243
+ torch.accelerator.current_stream()
1244
+ if torch.accelerator.is_available()
1245
+ else None
1246
+ )
1247
+
1248
+ # To ensure the future callback chain is only created once
1249
+ self._is_set_future_callback_called = False
1250
+
1251
+ def _set_future_callback(
1252
+ self,
1253
+ ) -> None:
1254
+ """
1255
+ Sets up the stored future callback chain.
1256
+
1257
+ This method creates a chain of callbacks for the futures in the managed work,
1258
+ ensuring that each callback is executed in the proper order and with the
1259
+ appropriate stream context. It also wraps the futures with error handling
1260
+ through the manager's `wrap_future` method.
1261
+
1262
+ The method is called internally when waiting or synchronizing on the work.
1263
+ """
1264
+ if self._is_set_future_callback_called:
1265
+ return
1266
+
1267
+ managed_fut: _ManagedFuture[object] = self._managed_fut_head
1268
+ managed_fut._fut = self._work.get_future()
1269
+ value = self._value
1270
+
1271
+ is_future_wrapped = False
1272
+ while managed_fut._next:
1273
+
1274
+ def callback(
1275
+ fut: torch.futures.Future[object],
1276
+ ) -> object:
1277
+ nonlocal managed_fut, value
1278
+ # change the stream to avoid making the callback stream
1279
+ # dependent on process group stream running the allreduce
1280
+ with get_stream_context(self._stream):
1281
+ # Setup stream dependency
1282
+ fut.wait()
1283
+ assert managed_fut._callback
1284
+ value = managed_fut._callback(
1285
+ _SimpleFuture(value),
1286
+ )
1287
+ return value
1288
+
1289
+ assert managed_fut._fut
1290
+ fut = managed_fut._fut.then(callback)
1291
+ assert managed_fut._next
1292
+ managed_fut = managed_fut._next
1293
+ managed_fut._fut = fut
1294
+
1295
+ if is_future_wrapped:
1296
+ continue
1297
+
1298
+ managed_fut._fut = self._manager.wrap_future(managed_fut._fut, value)
1299
+ is_future_wrapped = True
1300
+
1301
+ self._value = value
1302
+ self._is_set_future_callback_called = True
1303
+
1304
+ def _assert_same_stream(self) -> None:
1305
+ """
1306
+ Asserts that the current accelerator stream is the same as the one used to create this work.
1307
+
1308
+ This makes sure users of the API are aware of stream dependencies.
1309
+ """
1310
+ if self._stream is not None:
1311
+ assert self._stream == torch.accelerator.current_stream()
1312
+
1313
+ def wait(self, timeout: Optional[timedelta] = None) -> bool:
1314
+ self._assert_same_stream()
1315
+
1316
+ try:
1317
+ with get_stream_context(self._stream):
1318
+ self._work.wait()
1319
+ self._set_future_callback()
1320
+
1321
+ with get_stream_context(self._stream):
1322
+ self._managed_fut_tail.wait()
1323
+
1324
+ return True
1325
+ except Exception as e:
1326
+ self._manager._logger.exception(f"got exception waiting for work {e}")
1327
+ self._manager.report_error(e)
1328
+ return False
1329
+
1330
+ def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
1331
+ self._assert_same_stream()
1332
+
1333
+ with get_stream_context(self._stream):
1334
+ self._work.block_current_stream()
1335
+
1336
+ self._set_future_callback()
1337
+
1338
+ def synchronize(self) -> None:
1339
+ self._assert_same_stream()
1340
+
1341
+ if torch.cuda.is_available():
1342
+ self.block_current_stream()
1343
+ elif torch.xpu.is_available():
1344
+ self._set_future_callback()
1345
+ else:
1346
+ # No stream dependencies need to be set
1347
+ self._set_future_callback()
1348
+
1349
+ def get_future(
1350
+ self,
1351
+ ) -> torch.futures.Future[object]:
1352
+ """
1353
+ Returns:
1354
+ The tail of the managed future chain, which represents the final
1355
+ result of all the chained operations. This future will be completed when
1356
+ all the work and its callbacks have been executed.
1357
+ """
1358
+ return self._managed_fut_tail