torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager.py
ADDED
|
@@ -0,0 +1,1358 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
"""
|
|
8
|
+
Manager
|
|
9
|
+
=========
|
|
10
|
+
|
|
11
|
+
This module implements the Manager that manages the full fault tolerant training
|
|
12
|
+
loop.
|
|
13
|
+
|
|
14
|
+
The Manager is responsible for managing the
|
|
15
|
+
full training loop, communicating with the Lighthouse server to figure out
|
|
16
|
+
quorum, reconfiguring the ProcessGroups and restoring checkpoint state when
|
|
17
|
+
recovering.
|
|
18
|
+
|
|
19
|
+
This uses wrapper classes to wrap the standard PyTorch Optimizer and Module
|
|
20
|
+
classes to provide fault tolerance. These wrappers are intended to add fault
|
|
21
|
+
tolerance with minimal changes to the user's modeling code and training loop.
|
|
22
|
+
|
|
23
|
+
This is designed to work with the standard PyTorch DistributedDataParallel module
|
|
24
|
+
and Hybrid FSDP.
|
|
25
|
+
|
|
26
|
+
"""
|
|
27
|
+
|
|
28
|
+
import concurrent.futures
|
|
29
|
+
import logging
|
|
30
|
+
import os
|
|
31
|
+
import socket
|
|
32
|
+
import traceback
|
|
33
|
+
import uuid
|
|
34
|
+
import weakref
|
|
35
|
+
from concurrent.futures import ThreadPoolExecutor
|
|
36
|
+
from contextlib import nullcontext
|
|
37
|
+
from datetime import timedelta
|
|
38
|
+
from enum import Enum
|
|
39
|
+
from typing import (
|
|
40
|
+
Any,
|
|
41
|
+
Callable,
|
|
42
|
+
cast,
|
|
43
|
+
Dict,
|
|
44
|
+
List,
|
|
45
|
+
Optional,
|
|
46
|
+
TYPE_CHECKING,
|
|
47
|
+
TypeAlias,
|
|
48
|
+
TypeVar,
|
|
49
|
+
Union,
|
|
50
|
+
)
|
|
51
|
+
|
|
52
|
+
import torch
|
|
53
|
+
import torch.distributed as dist
|
|
54
|
+
from torch.distributed import ReduceOp, TCPStore
|
|
55
|
+
from torch.distributed.distributed_c10d import AllreduceOptions, ReduceOp, Work
|
|
56
|
+
|
|
57
|
+
from torchft._torchft import ManagerClient, ManagerServer
|
|
58
|
+
from torchft.checkpointing import CheckpointTransport, HTTPTransport
|
|
59
|
+
from torchft.checkpointing._rwlock import RWLock
|
|
60
|
+
from torchft.futures import future_timeout
|
|
61
|
+
from torchft.utils import get_stream_context, synchronize
|
|
62
|
+
from torchft.work import _DummyWork
|
|
63
|
+
|
|
64
|
+
if TYPE_CHECKING:
|
|
65
|
+
from torchft.process_group import ProcessGroup
|
|
66
|
+
|
|
67
|
+
# Whether the optional quantized allreduce path can be used. Flipped to False
# if importing either `triton` or `torchft.collectives` raises ImportError.
IS_TRITON_AVAILABLE = True
try:
    # pyre-ignore[21]: Could not find a module corresponding to import `triton`
    import triton

    from torchft.collectives import allreduce_quantized
except ImportError:
    IS_TRITON_AVAILABLE = False

# TCPStore keys under which group rank 0 publishes the manager server address
# and the replica id for the other ranks to read.
MANAGER_ADDR_KEY: str = "manager_addr"
# Environment variable consulted for the manager server port when no explicit
# port is passed to the Manager.
MANAGER_PORT_ENV: str = "TORCHFT_MANAGER_PORT"
REPLICA_ID_KEY: str = "replica_id"

# Environment variables for various timeouts. These can also be passed
# in through the manager but the environment variables take precedence.
TIMEOUT_SEC_ENV: str = "TORCHFT_TIMEOUT_SEC"
QUORUM_TIMEOUT_SEC_ENV: str = "TORCHFT_QUORUM_TIMEOUT_SEC"
CONNECT_TIMEOUT_SEC_ENV: str = "TORCHFT_CONNECT_TIMEOUT_SEC"

# Environment variable for the number of retries to use for the quorum.
# We need to retry quorum in case the lighthouse fails. Otherwise, if we
# crashed whenever a call to quorum failed, all replicas would crash.
QUORUM_RETRIES_ENV: str = "TORCHFT_QUORUM_RETRIES"

# Environment variable whose original value the Manager records at
# construction time.
TORCH_FR_DUMP_TEMP_FILE_ENV: str = "TORCH_FR_DUMP_TEMP_FILE"

# Generic type of the user-provided state dict.
T = TypeVar("T")
|
|
94
|
+
|
|
95
|
+
|
|
96
|
+
def get_timeout(
|
|
97
|
+
timeout_sec_env: str | None, default_timeout_sec: timedelta
|
|
98
|
+
) -> timedelta:
|
|
99
|
+
"""
|
|
100
|
+
Get the timeout from the environment variable or the default value.
|
|
101
|
+
|
|
102
|
+
Args:
|
|
103
|
+
timeout_sec_env: The environment variable for the timeout
|
|
104
|
+
default_timeout_sec: The default timeout
|
|
105
|
+
Returns:
|
|
106
|
+
The timeout to use. Environment variable takes precedence.
|
|
107
|
+
"""
|
|
108
|
+
if timeout_sec_env is not None:
|
|
109
|
+
return timedelta(seconds=int(timeout_sec_env))
|
|
110
|
+
|
|
111
|
+
return default_timeout_sec
|
|
112
|
+
|
|
113
|
+
|
|
114
|
+
def extract_trailing_digits(s: str) -> int:
    """
    Extracts the trailing decimal digits from the end of the string s.

    Args:
        s: the string to inspect, e.g. a replica id such as ``"replica_12"``
    Returns:
        The trailing digits parsed as an int, or 0 if s does not end in a
        digit (including when s is empty).
    """
    i = len(s) - 1
    # Use isdecimal() rather than isdigit(): isdigit() is also true for
    # characters such as superscripts ("²") that int() cannot parse, which
    # would make the conversion below raise ValueError.
    while i >= 0 and s[i].isdecimal():
        i -= 1
    return int(s[i + 1 :]) if i < len(s) - 1 else 0
|
|
123
|
+
|
|
124
|
+
|
|
125
|
+
class WorldSizeMode(Enum):
    """
    This controls the numerics for the job when doing allreduces across replicas
    when the world size is larger than ``min_replica_size``. The world size will
    never be smaller than ``min_replica_size``.

    DYNAMIC:
        The world size will dynamically increase to use all available
        replicas and normalize the gradient by the world size.
    FIXED_WITH_SPARES:
        The number of active replicas is ``min_replica_size`` and any spares
        will contribute zero gradients.
    """

    DYNAMIC = 0
    FIXED_WITH_SPARES = 1
|
|
141
|
+
|
|
142
|
+
|
|
143
|
+
class ExceptionWithTraceback(Exception):
    """
    Wraps an exception together with its formatted stack trace.

    Attributes are set in ``__init__``:
      * ``original_exception``: the wrapped exception object
      * ``stack_trace``: the formatted traceback of the wrapped exception

    The message of this exception is ``"{e}\\n{stack_trace}"``.
    """

    def __init__(self, e: Exception) -> None:
        """
        Args:
            e: the exception to wrap
        """
        self.original_exception = e
        # Format the traceback attached to ``e`` itself. traceback.format_exc()
        # reports whichever exception is currently being handled, which is a
        # different exception (or "NoneType: None") when this wrapper is built
        # outside of the ``except`` block that caught ``e``.
        self.stack_trace: str = "".join(
            traceback.format_exception(type(e), e, e.__traceback__)
        )
        super().__init__(f"{e}\n{self.stack_trace}")
|
|
148
|
+
|
|
149
|
+
|
|
150
|
+
class Manager:
|
|
151
|
+
"""
|
|
152
|
+
Manager manages the full fault tolerant training loop.
|
|
153
|
+
|
|
154
|
+
This requires that the TCPStore specified by the store_addr and
|
|
155
|
+
store_port or MASTER_ADDR and MASTER_PORT environment variables to be
|
|
156
|
+
started prior to creating this manager. If using a modern version of
|
|
157
|
+
torchelastic this will already be the case. Otherwise, it should be started
|
|
158
|
+
via torch.distributed.init_process_group prior to creating this manager.
|
|
159
|
+
|
|
160
|
+
NOTE: when saving periodic checkpoints you must save and restore the
|
|
161
|
+
Manager's state_dict as well to avoid synchronization issues.
|
|
162
|
+
"""
|
|
163
|
+
|
|
164
|
+
def __init__(
    self,
    pg: "ProcessGroup",
    load_state_dict: Optional[Callable[[T], None]],
    state_dict: Optional[Callable[[], T]],
    min_replica_size: int,
    use_async_quorum: bool = True,
    timeout: timedelta = timedelta(seconds=60),
    quorum_timeout: timedelta = timedelta(seconds=60),
    connect_timeout: timedelta = timedelta(seconds=60),
    rank: Optional[int] = None,
    world_size: Optional[int] = None,
    world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
    store_addr: Optional[str] = None,
    store_port: Optional[int] = None,
    lighthouse_addr: Optional[str] = None,
    replica_id: Optional[str] = None,
    port: Optional[int] = None,
    hostname: str = socket.gethostname(),
    heartbeat_interval: timedelta = timedelta(milliseconds=100),
    checkpoint_transport: Optional[CheckpointTransport[Dict[str, T]]] = None,
    init_sync: bool = True,
    max_retries: Optional[int] = None,
    quorum_retries: int = 0,
) -> None:
    """
    Args:
        pg: the process group used for the fault tolerant allreduce
        load_state_dict: function to load the state dict when recovering
        state_dict: function to save the state dict when recovering.
            ``load_state_dict`` and ``state_dict`` are registered under the
            key ``"default"`` only if both are provided.
        min_replica_size: minimum number of replicas on each step
        port: if rank==0, the port to run the manager server on.
            Port assignment priority:
            1. this argument
            2. TORCHFT_MANAGER_PORT env var
            3. arbitrary port assigned via 0
        use_async_quorum: whether to run the quorum asynchronously during the forward pass
        timeout: the default timeout for all operations
            Included:
            * collectives such as allreduce
            * should_commit rpc
            * checkpoint_address rpc
            * checkpoint HTTP operations
            * wrap_future
            Overridden by the TORCHFT_TIMEOUT_SEC env var.
        quorum_timeout: the default timeout to wait for the quorum to complete.
            This generally should be longer than the training step time /
            the interval between quorum checks to avoid any split brain
            issues.

            For LocalSGD/DiLoCo this may need to be set to ~1h or longer
            depending on how frequently the syncs occur.
            Overridden by the TORCHFT_QUORUM_TIMEOUT_SEC env var.
        connect_timeout: the timeout used for establishing rpc connections
            to ManagerServer and Lighthouse.
            Overridden by the TORCHFT_CONNECT_TIMEOUT_SEC env var.
        rank: the replica group local rank, referred to as group_rank in manager.py for clarity.
            Defaults to the RANK env var.
        world_size: the replica group local world size, referred to as group_world_size in manager.py for clarity.
            Defaults to the WORLD_SIZE env var.
        world_size_mode: see :py:class:`WorldSizeMode`
        store_addr: TCPStore address for this replica group, defaults to the
            MASTER_ADDR env var
        store_port: TCPStore port for this replica group, defaults to the
            MASTER_PORT env var
        lighthouse_addr: if rank==0, the address of the lighthouse server,
            defaults to the TORCHFT_LIGHTHOUSE env var
        replica_id: if rank==0, the replica_id for this group. A fresh UUID
            is always appended (``"{replica_id}:{uuid}"``), or used on its
            own if no replica_id is given.
        hostname: if rank==0, the hostname to advertise to the lighthouse server
        heartbeat_interval: if rank==0, the heartbeat interval passed to the
            ManagerServer
        checkpoint_transport: the checkpoint transport to use for
            transferring checkpoints to recovering replicas, defaults to HTTPTransport
        init_sync: whether to synchronize the model weights on step 0. If
            all of the model weights are initialized identically via
            ``torch.manual_seed`` you should set this to False.
        max_retries: the maximum number of consecutive should_commit failures to allow
            before raising an exception. If None, will retry indefinitely.
        quorum_retries: the number of times to retry the quorum before crashing.
            Overridden by the TORCHFT_QUORUM_RETRIES env var.
    """
    self.quorum_logger: logging.Logger = logging.getLogger("torchft_quorums")
    self.commits_logger: logging.Logger = logging.getLogger("torchft_commits")
    self.errors_logger: logging.Logger = logging.getLogger("torchft_errors")

    # Per-key user callbacks, populated via register_state_dict_fn.
    self._load_state_dict_fns: Dict[str, Callable[[object], None]] = {}
    self._user_state_dicts: Dict[str, Callable[[], object]] = {}

    # Remember the value the env var had when the manager was created.
    self._original_fr_dump_temp_file: Optional[str] = os.environ.get(
        TORCH_FR_DUMP_TEMP_FILE_ENV
    )
    # The user supplied replica id, before any UUID suffix is appended below.
    self._replica_id = replica_id

    # Protects state dict
    # NOTE(review): this uses the raw ``timeout`` argument rather than the
    # env-overridden ``self._timeout`` computed below -- confirm intended.
    self._state_dict_lock = RWLock(timeout=timeout.total_seconds())

    if load_state_dict and state_dict:
        self.register_state_dict_fn("default", load_state_dict, state_dict)

    self._pending_state_dict: Optional[Dict[str, object]] = None
    self._use_async_quorum = use_async_quorum

    # Environment variables take precedence over the constructor arguments.
    self._timeout: timedelta = get_timeout(
        os.environ.get(TIMEOUT_SEC_ENV, None), timeout
    )
    self._quorum_timeout: timedelta = get_timeout(
        os.environ.get(QUORUM_TIMEOUT_SEC_ENV, None), quorum_timeout
    )
    self._connect_timeout: timedelta = get_timeout(
        os.environ.get(CONNECT_TIMEOUT_SEC_ENV, None), connect_timeout
    )

    self._replica_world_size_mode = world_size_mode
    self._init_sync = init_sync
    self._max_retries = max_retries
    self._commit_failures = 0

    self._quorum_retries: int = int(
        os.environ.get(QUORUM_RETRIES_ENV, str(quorum_retries))
    )

    # Fall back to the standard torch.distributed environment variables.
    store_addr = store_addr or os.environ["MASTER_ADDR"]
    store_port = store_port or int(os.environ["MASTER_PORT"])
    self._group_rank: int = rank if rank is not None else int(os.environ["RANK"])
    group_rank = self._group_rank
    self._group_world_size: int = world_size or int(os.environ["WORLD_SIZE"])
    self._min_replica_size = min_replica_size

    if checkpoint_transport is None:
        # NOTE(review): passes the raw ``timeout`` argument, not the
        # env-overridden ``self._timeout`` -- confirm intended.
        checkpoint_transport = HTTPTransport[Dict[str, T]](
            timeout=timeout,
            num_chunks=0,
        )

    self._checkpoint_transport: CheckpointTransport[Dict[str, T]] = (
        checkpoint_transport
    )
    # Single worker thread on which _async_quorum runs.
    self._executor = ThreadPoolExecutor(
        max_workers=1, thread_name_prefix="async_quorum"
    )
    self._quorum_future: Optional[concurrent.futures.Future] = None

    # Connect as a client to the replica group's already running TCPStore.
    self._store = TCPStore(
        host_name=store_addr,
        port=store_port,
        is_master=False,
        wait_for_workers=False,
    )
    self._pg = pg
    # Only set on group rank 0, which hosts the ManagerServer.
    self._manager: Optional[ManagerServer] = None

    self._recovery_stream: Optional["torch.Stream"] = (
        torch.Stream() if torch.accelerator.is_available() else None
    )

    # Used to synchronize recovery operation
    self._recovery_event: Optional[torch.Event] = None

    if self._group_rank == 0:
        if port is None:
            port = int(os.environ.get(MANAGER_PORT_ENV, 0))

        bind = f"[::]:{port}"
        lighthouse_addr = lighthouse_addr or os.environ["TORCHFT_LIGHTHOUSE"]

        # We need a unique identifier in the case that a worker restarts quickly and
        # replaces the previous worker with the same ID.
        new_uuid = str(uuid.uuid4())
        if replica_id is None or replica_id == "":
            replica_id = new_uuid
        else:
            replica_id = f"{replica_id}:{new_uuid}"
        # NOTE(review): ``connect_timeout`` here (and for ManagerClient below)
        # is the raw argument, not ``self._connect_timeout`` -- confirm intended.
        self._manager = ManagerServer(
            replica_id=replica_id,
            lighthouse_addr=lighthouse_addr,
            hostname=hostname,
            bind=bind,
            store_addr=f"{store_addr}:{store_port}",
            world_size=self._group_world_size,
            heartbeat_interval=heartbeat_interval,
            connect_timeout=connect_timeout,
            quorum_retries=self._quorum_retries,
        )

        # Publish the server address and replica id for the other ranks.
        self._store.set(MANAGER_ADDR_KEY, self._manager.address())
        self._store.set(REPLICA_ID_KEY, replica_id)

    # Every rank reads the manager address back from the store.
    addr = self._store.get(MANAGER_ADDR_KEY).decode("utf-8")
    self._client = ManagerClient(addr, connect_timeout=connect_timeout)

    replica_id = self._store.get(REPLICA_ID_KEY).decode("utf-8")
    self._logger = _ManagerLogger(
        manager=self, replica_id=replica_id or "", group_rank=group_rank
    )

    self._step = 0
    self._quorum_id = -1
    self._errored: Optional[ExceptionWithTraceback] = None
    self._healing = False
    self._batches_committed = 0

    # first step is 1
    self._participating_replica_rank: Optional[int] = None
    self._participating_replica_world_size: int = 0
    # When False, the write side of _state_dict_lock is held
    # (see disallow_state_dict_read / allow_state_dict_read).
    self._is_state_dict_read_allowed = True

    # Derived from the trailing digits of the user supplied replica id
    # (without the UUID suffix); just the group rank if none was given.
    self._global_rank: int = (
        self._group_rank
        if self._replica_id is None
        else (
            extract_trailing_digits(self._replica_id) * self._group_world_size
            + self._group_rank
        )
    )

    self._update_fr_path()
|
|
367
|
+
|
|
368
|
+
def allow_state_dict_read(self) -> None:
|
|
369
|
+
if self._is_state_dict_read_allowed:
|
|
370
|
+
return
|
|
371
|
+
|
|
372
|
+
self._is_state_dict_read_allowed = True
|
|
373
|
+
self._state_dict_lock.w_release()
|
|
374
|
+
|
|
375
|
+
def disallow_state_dict_read(self) -> None:
|
|
376
|
+
if not self._is_state_dict_read_allowed:
|
|
377
|
+
return
|
|
378
|
+
|
|
379
|
+
self._is_state_dict_read_allowed = False
|
|
380
|
+
self._state_dict_lock.w_acquire()
|
|
381
|
+
|
|
382
|
+
def register_state_dict_fn(
|
|
383
|
+
self,
|
|
384
|
+
key: str,
|
|
385
|
+
load_state_dict: Callable[[T], None],
|
|
386
|
+
state_dict: Callable[[], T],
|
|
387
|
+
) -> None:
|
|
388
|
+
# Can't register duplicate keys
|
|
389
|
+
assert key not in self._load_state_dict_fns
|
|
390
|
+
assert key not in self._user_state_dicts
|
|
391
|
+
|
|
392
|
+
self._load_state_dict_fns[key] = cast(Callable[[object], None], load_state_dict)
|
|
393
|
+
self._user_state_dicts[key] = state_dict
|
|
394
|
+
|
|
395
|
+
def set_state_dict_fns(
    self, load_state_dict: Callable[[T], None], state_dict: Callable[[], T]
) -> None:
    """
    Deprecated: use :py:meth:`register_state_dict_fn` instead.

    Logs a deprecation warning and registers the callbacks under the fixed
    key ``"set_state_dict_fns"``. Because that key is fixed, a second call
    trips the duplicate-key assertion in ``register_state_dict_fn``.

    Args:
        load_state_dict: function that loads a state dict
        state_dict: function that returns the state dict
    """
    self._logger.warn(
        "`set_state_dict_fns` is deprecated, please use `register_state_dict_fn` instead"
    )
    self.register_state_dict_fn("set_state_dict_fns", load_state_dict, state_dict)
|
|
402
|
+
|
|
403
|
+
def shutdown(self, wait: bool = True) -> None:
    """
    Shutdown the manager and checkpoint server.

    Shuts down, in order: the checkpoint transport, the ManagerServer (only
    present on group rank 0) and the quorum executor.

    Args:
        wait: forwarded to the checkpoint transport's and the executor's
            ``shutdown``
    """
    self._checkpoint_transport.shutdown(wait=wait)
    if self._manager is not None:
        self._manager.shutdown()
    self._executor.shutdown(wait=wait)
|
|
411
|
+
|
|
412
|
+
@torch.profiler.record_function("torchft::manager::allreduce")
def allreduce(
    self,
    tensor: torch.Tensor,
    should_quantize: bool = False,
    reduce_op: ReduceOp = ReduceOp.AVG,
) -> Work:
    """
    Fault tolerant allreduce the tensor and return a Work object that will be
    completed when the tensor is ready.

    With the default ``ReduceOp.AVG`` the tensor is summed across replicas
    and then divided in place by the number of participants. Other reduce
    ops are passed to the process group unchanged and are not scaled.

    This waits for the pending quorum before issuing the collective. If this
    replica is not participating, the tensor is zeroed in place first.

    If an error occurs during the allreduce:

    * The Work will be completed with no error and instead tracked asynchronously.
    * After the first error, all subsequent calls will be noops and immediately return.
    * The tensor must be zeroed before being used as it may be corrupted.

    Args:
        tensor: the tensor to allreduce
        should_quantize: whether the tensor should be quantized before communication.
            Only takes effect when triton is available; otherwise the
            regular allreduce is used.
        reduce_op: the reduce operation to apply, defaults to average
    Returns:
        a Work object wrapping the allreduce of ``tensor``
    Raises:
        ValueError: if ``reduce_op`` is ``ReduceOp.AVG`` and the tensor is
            not a floating point tensor
    """
    # A previous error this step: skip the collective entirely.
    if self.errored():
        return _DummyWork(tensor)

    self.wait_quorum()
    num_participants: int = self.num_participants()

    if not self.is_participating():
        tensor.zero_()

    # special logic for average
    # (performed as a SUM on the process group followed by a division in the
    # callback below)
    pg_reduce_op = reduce_op
    if reduce_op == ReduceOp.AVG:
        if not torch.is_floating_point(tensor):
            raise ValueError(
                "average reduce op is only supported for floating point tensors"
            )
        pg_reduce_op = ReduceOp.SUM

    # TODO: increase timeout when waiting when healing
    try:
        # Run the allreduce async and save the work object so we can wait on
        # it later.
        if should_quantize and IS_TRITON_AVAILABLE:
            work = allreduce_quantized(
                [tensor],
                pg_reduce_op,
                self._pg,
                # pyre-fixme[6]: Expected `Optional[streams.Stream]` but got `_C.Stream`
                torch.accelerator.current_stream(),
            )
        else:
            opts = AllreduceOptions()
            opts.reduceOp = pg_reduce_op
            work = self._pg.allreduce([tensor], opts)

        # schedule grad normalization as a continuation
        # on the Future
        @torch.profiler.record_function("torchft::manager::allreduce::callback")
        def callback(
            fut: torch.futures.Future[torch.Tensor],
        ) -> torch.Tensor:
            # nonlocal is required: the augmented assignment below would
            # otherwise make ``tensor`` a local of the callback.
            nonlocal tensor
            if reduce_op == ReduceOp.AVG:
                tensor /= num_participants
            return tensor

        managed_work = _ManagedWork(self, work, tensor)
        fut = managed_work.get_future()
        fut = cast(torch.futures.Future[torch.Tensor], fut)
        fut = fut.then(callback)
        return managed_work

    except Exception as e:
        self._logger.exception(
            f"got exception in all reduce -- skipping remaining: {e}"
        )
        # Record the error instead of raising, so the failure is tracked
        # by the manager rather than surfaced to the caller here.
        self.report_error(e)

        return _DummyWork(tensor)
|
|
496
|
+
|
|
497
|
+
def report_error(self, e: Exception) -> None:
    """
    Report an error to the manager.

    This will cause the manager to skip the current step and will be
    reconfigured on the next step.

    This should be called when an error occurs that leads to a corrupted
    gradient that needs to be discarded.

    Args:
        e: the exception to record. It is wrapped in an
            :py:class:`ExceptionWithTraceback` and replaces any previously
            reported error.
    """
    self._errored = ExceptionWithTraceback(e)
|
|
508
|
+
|
|
509
|
+
def errored(self) -> Optional[ExceptionWithTraceback]:
    """
    Get whether an error has occurred.

    The error is set by :py:meth:`report_error` and cleared by
    :py:meth:`start_quorum`.

    Returns:
        The error or None if no error has occurred.
    """
    return self._errored
|
|
517
|
+
|
|
518
|
+
def wrap_future(
    self,
    fut: torch.futures.Future[T],
    default: T,
    timeout: Optional[timedelta] = None,
) -> torch.futures.Future[T]:
    """
    Return a Future that never fails: errors from ``fut`` are logged and
    reported to the manager, and the returned Future resolves to ``default``
    in that case.

    Args:
        fut: the Future to wrap
        default: the value the returned Future resolves to if an error occurs
        timeout: the timeout for the Future, if None, the manager's timeout will be used
    """
    timed_fut = future_timeout(fut, timeout or self._timeout)

    # Capture the stream that is current when wrap_future is called so the
    # continuation runs under the same stream context.
    stream: Optional[torch.Stream] = None
    if torch.accelerator.is_available():
        stream = torch.accelerator.current_stream()

    # error handling runs as a continuation on the Future
    def _swallow_errors(
        completed: torch.futures.Future[T],
    ) -> T:
        with get_stream_context(stream):
            try:
                return completed.value()
            except Exception as e:
                self._logger.exception(
                    f"got exception in future -- skipping remaining: {e}"
                )
                self.report_error(e)
                return default

    return timed_fut.then(_swallow_errors)
|
|
561
|
+
|
|
562
|
+
def start_quorum(
    self,
    allow_heal: bool = True,
    shrink_only: bool = False,
    timeout: Optional[timedelta] = None,
) -> None:
    """
    .. note::
        We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.

    Computes a new quorum (potentially asynchronously) and readies the
    manager for a new step.

    It's best practice to call this before the forwards pass of each step for
    performance as computing quorum may take some time.

    Any error recorded by a previous step is cleared. If a previous quorum
    is still in flight this blocks until it completes (re-raising its
    exception, if any).

    Args:
        allow_heal: (experimental) whether to allow healing at the beginning of the step
            If allow_heal is set, the manager will attempt to heal either
            synchronously before returning or asynchronously prior to any network
            calls. All replicas must pass the same value to allow_heal.
        shrink_only: forwarded to the quorum request
        timeout: the timeout for quorum to be ready, if None, the manager's
            quorum timeout will be used;
            recovery operations will use the manager timeout
    """

    # wait for previous quorum to complete
    if self._quorum_future is not None:
        self._quorum_future.result()

    self._errored = None
    self._healing = False

    # TODO: we should really be wrapping this whole section in a try-except
    # block to allow gracefully recovering from issues in PG setup and quorum.

    self._quorum_future = self._executor.submit(
        self._async_quorum,
        allow_heal=allow_heal,
        shrink_only=shrink_only,
        quorum_timeout=timeout or self._quorum_timeout,
        # Pass the caller's current device index to the quorum thread;
        # -1 means no accelerator is available.
        curr_device=(
            torch.accelerator.current_device_index()
            if torch.accelerator.is_available()
            else -1
        ),
    )
    if not self._use_async_quorum:
        # Synchronous mode: block here until the quorum has completed.
        self.wait_quorum()

        if self._healing:
            # eagerly apply pending state_dict so we can run the forwards pass
            self._apply_pending_state_dict()

            # we are forcing healing at the beginning so we're in a good state
            # and don't need to zero_grad
            self._healing = False
|
|
618
|
+
|
|
619
|
+
@torch.profiler.record_function("torchft::manager::wait_quorum")
def wait_quorum(self) -> None:
    """
    Wait for the quorum to complete.

    ProcessGroup will be in a healthy state after this returns.

    Blocks on the future created by :py:meth:`start_quorum`; any exception
    raised by the quorum computation is re-raised here.

    Raises:
        AssertionError: if ``start_quorum`` has not been called yet
    """
    assert (
        self._quorum_future is not None
    ), "must call start_quorum before wait_quorum"
    self._quorum_future.result()
|
|
630
|
+
|
|
631
|
+
@torch.profiler.record_function("torchft::manager::_async_quorum")
def _async_quorum(
    self,
    allow_heal: bool,
    shrink_only: bool,
    quorum_timeout: timedelta,
    curr_device: int,
) -> None:
    """
    Request a quorum, reconfigure the process group when the quorum id changed
    and, if ``allow_heal`` is set, send and/or receive recovery checkpoints.

    This is the callable that ``start_quorum`` submits to ``self._executor``.

    Errors raised while configuring the process group or during recovery are
    logged and handed to ``report_error`` instead of being propagated. An
    exception from the quorum request itself is not caught here.

    Args:
        allow_heal: when True, run the checkpoint send/receive recovery logic;
            also influences which rank/world size is recorded as participating
        shrink_only: forwarded unchanged to the quorum request
        quorum_timeout: timeout passed to the quorum request; recovery
            operations use ``self._timeout`` instead
        curr_device: accelerator device index to select for this thread;
            negative values skip device selection
    """
    torch.multiprocessing._set_thread_name("torchft_quorum")

    # Select the caller's device in this thread before doing any work.
    if curr_device >= 0 and torch.accelerator.is_available():
        torch.accelerator.set_device_index(curr_device)

    quorum = None
    with torch.profiler.record_function("torchft::manager::_client::_quorum"):
        quorum = self._client._quorum(
            group_rank=self._group_rank,
            step=self._step,
            checkpoint_metadata=self._checkpoint_transport.metadata(),
            shrink_only=shrink_only,
            timeout=quorum_timeout,
            init_sync=self._init_sync,
            commit_failures=self._commit_failures,
        )

    # Unpack the fields of the quorum result used below.
    quorum_id = quorum.quorum_id
    replica_rank = quorum.replica_rank
    replica_world_size = quorum.replica_world_size
    recover_src_manager_address = quorum.recover_src_manager_address
    store_address = quorum.store_address
    max_step = quorum.max_step
    max_replica_rank = quorum.max_replica_rank
    max_replica_world_size = quorum.max_world_size
    heal = quorum.heal
    replica_ids = quorum.replica_ids

    # One rank per replica id: the trailing digits of the part before ":" are
    # scaled by the group world size and offset by this worker's group rank.
    ranks_in_quorum = [
        extract_trailing_digits(replica_id.split(":")[0]) * self._group_world_size
        + self._group_rank
        for replica_id in replica_ids
    ]

    # When using async quorum we need to take the recovered workers.
    # When not using async quorum we need to take the max world size as all
    # workers will be healthy.
    self._participating_replica_rank, self._participating_replica_world_size = (
        (max_replica_rank, max_replica_world_size)
        if self._use_async_quorum or not allow_heal
        else (replica_rank, replica_world_size)
    )

    # For fixed with spares we need to ensure that we don't have more
    # participating replicas than the min replica size.
    if self._replica_world_size_mode == WorldSizeMode.FIXED_WITH_SPARES:
        self._participating_replica_world_size = min(
            self._participating_replica_world_size, self._min_replica_size
        )
        # Ranks at or beyond the min replica size are marked non-participating.
        if (
            self._participating_replica_rank is not None
            and self._participating_replica_rank >= self._min_replica_size
        ):
            self._participating_replica_rank = None

    # Only reconfigure when the quorum id differs from the last configured one.
    if quorum_id != self._quorum_id:
        self.quorum_logger.info(
            "",
            extra={
                "job_id": os.environ.get("JOB_ID", "unknown"),
                "replica_id": self._replica_id,
                "rank": self._group_rank,
                "quorum_id": quorum_id,
                "step": max_step,
            },
        )
        store_prefixed_addr = (
            f"{store_address}/torchft/{quorum_id}/{self._group_rank}"
        )

        self._logger.info(f"reconfiguring for {quorum_id=} {store_prefixed_addr=}")
        # We use the replica rank and world as we want all replicas in the PG.
        try:
            # The new id is stored before configure() runs, so it stays set
            # even if configuration fails below.
            self._quorum_id = quorum_id
            with torch.profiler.record_function("torchft::manager::_pg::configure"):
                # Reset GPU state for Flight Recorder
                if torch.accelerator.is_available():
                    torch.accelerator.synchronize()

                self._pg.configure(
                    store_prefixed_addr,
                    self._replica_id if self._replica_id is not None else "0",
                    replica_rank,
                    replica_world_size,
                    quorum_id,
                    self._group_rank,
                    self._group_world_size,
                    ranks_in_quorum,
                )

            # We need to reset the trace after reconfiguring the PG because that
            # calls abort which may trigger a dump
            self._logger.info(
                f"resetting fr recording for quorum id {self._quorum_id}"
            )
            self._update_fr_path()
            torch._C._distributed_c10d._reset_fr_recording_nccl()  # pyre-ignore
        except Exception as e:
            self._logger.exception(f"got exception in pg configure: {e}")
            self.report_error(e)
            # Skip recovery entirely when the process group could not be set up.
            return

    if allow_heal:
        # run recovery on the recovery stream if available
        recovery_stream = self._recovery_stream
        with get_stream_context(recovery_stream):
            try:
                # Sending side: peers listed by the quorum fetch state from us.
                if quorum.recover_dst_replica_ranks:
                    self._logger.info(
                        f"peers need recovery from us {quorum.recover_dst_replica_ranks}"
                    )
                    with torch.profiler.record_function(
                        "torchft::manager::_checkpoint_transport::send_checkpoint"
                    ):
                        self._checkpoint_transport.send_checkpoint(
                            dst_ranks=quorum.recover_dst_replica_ranks,
                            step=max_step,
                            state_dict=self._manager_state_dict(),
                            timeout=self._timeout,
                        )

                # See manager.rs for healing conditions
                if heal:
                    # Receiving side: fetch a checkpoint from the source replica.
                    self._healing = True
                    self._logger.info(
                        f"healing required, fetching checkpoint metadata from {recover_src_manager_address=} {max_step=}"
                    )
                    primary_client = ManagerClient(
                        recover_src_manager_address,
                        connect_timeout=self._connect_timeout,
                    )
                    checkpoint_metadata = primary_client._checkpoint_metadata(
                        self._group_rank, timeout=self._timeout
                    )
                    recover_src_replica_rank = quorum.recover_src_replica_rank
                    assert (
                        recover_src_replica_rank is not None
                    ), "must have a recover rank when healing"

                    self._logger.info(
                        f"fetching checkpoint from {recover_src_replica_rank=} with {checkpoint_metadata=}"
                    )

                    # we apply the user state dict only when safe from the main thread
                    # save it for now
                    with torch.profiler.record_function(
                        "torchft::manager::_checkpoint_transport::recv_checkpoint"
                    ):
                        self._pending_state_dict = self._checkpoint_transport.recv_checkpoint(
                            src_rank=recover_src_replica_rank,
                            metadata=checkpoint_metadata,  # Depending on group rank
                            step=max_step,
                            timeout=self._timeout,
                        )

                    # Only the manager's own ("torchft") state is loaded here;
                    # the "user" part stays pending.
                    # pyre-fixme[6]: got object
                    self.load_state_dict(self._pending_state_dict["torchft"])

                    # This isn't strictly needed as loading the state_dict above should
                    # restore the correct step but it makes writing tests simpler.
                    self._step = max_step
            except Exception as e:
                self._logger.exception(f"got exception in recovery: {e}")
                self.report_error(e)

            # Record an event on the current stream so should_commit can wait
            # for recovery; no event is recorded without a recovery stream.
            self._recovery_event = (
                torch.accelerator.current_stream().record_event()
                if recovery_stream is not None
                else None
            )
|
|
809
|
+
|
|
810
|
+
def _update_fr_path(self) -> None:
|
|
811
|
+
"""
|
|
812
|
+
Update the path that flight recorder will dump the traces to.
|
|
813
|
+
The format is
|
|
814
|
+
<TORCH_FR_DUMP_TEMP_FILE_ENV>_quorum_<quorum_id>/<global_rank>
|
|
815
|
+
"""
|
|
816
|
+
if self._original_fr_dump_temp_file is not None:
|
|
817
|
+
folder = f"{self._original_fr_dump_temp_file}_quorum_{self._quorum_id}"
|
|
818
|
+
os.makedirs(folder, exist_ok=True)
|
|
819
|
+
os.environ[TORCH_FR_DUMP_TEMP_FILE_ENV] = f"{folder}/{self._global_rank}"
|
|
820
|
+
|
|
821
|
+
def _apply_pending_state_dict(self) -> None:
|
|
822
|
+
assert self._healing, "must be in healing state"
|
|
823
|
+
|
|
824
|
+
# synchronize on future
|
|
825
|
+
assert self._quorum_future is not None, "must call step before should_commit"
|
|
826
|
+
self._quorum_future.result()
|
|
827
|
+
|
|
828
|
+
pending_state_dict = self._pending_state_dict
|
|
829
|
+
|
|
830
|
+
if pending_state_dict is None:
|
|
831
|
+
assert self.errored(), "checkpoint was not staged and no error occured"
|
|
832
|
+
else:
|
|
833
|
+
self._logger.info("applying pending state dict")
|
|
834
|
+
|
|
835
|
+
assert (
|
|
836
|
+
len(self._load_state_dict_fns) > 0
|
|
837
|
+
), "user load_state_dict is not initialized."
|
|
838
|
+
|
|
839
|
+
pending_user_state_dict = cast(
|
|
840
|
+
Dict[str, object], pending_state_dict["user"]
|
|
841
|
+
)
|
|
842
|
+
|
|
843
|
+
for key in self._load_state_dict_fns.keys():
|
|
844
|
+
load_state_dict_fn = self._load_state_dict_fns[key]
|
|
845
|
+
load_state_dict_fn(pending_user_state_dict[key])
|
|
846
|
+
|
|
847
|
+
self._pending_state_dict = None
|
|
848
|
+
self._logger.info("Loaded state dict.")
|
|
849
|
+
|
|
850
|
+
@torch.profiler.record_function("torchft::manager::should_commit")
|
|
851
|
+
def should_commit(self, timeout: Optional[timedelta] = None) -> bool:
|
|
852
|
+
"""
|
|
853
|
+
.. note::
|
|
854
|
+
We recommend using the :py:class:`torchft.optim.OptimizerWrapper` instead of calling this directly.
|
|
855
|
+
|
|
856
|
+
Must be called after the backwards pass completes but before stepping the optimizer.
|
|
857
|
+
|
|
858
|
+
The optimizer must only be stepped if this returns True.
|
|
859
|
+
|
|
860
|
+
This must be called on all workers within a replica group. This uses a
|
|
861
|
+
collective to ensure all workers within a replica return the same value.
|
|
862
|
+
If an error occurs on any worker, all workers will return False.
|
|
863
|
+
Different replica groups may return different values.
|
|
864
|
+
|
|
865
|
+
This should only be called once per step.
|
|
866
|
+
|
|
867
|
+
If max_retries is set and should_commit fails that many times consecutively,
|
|
868
|
+
this method will raise a RuntimeError to prevent indefinite failure loops.
|
|
869
|
+
|
|
870
|
+
Returns:
|
|
871
|
+
True if the optimizer should be stepped, False otherwise
|
|
872
|
+
Raises:
|
|
873
|
+
RuntimeError: if should_commit fails max_retries times in a row and max_retries is set
|
|
874
|
+
"""
|
|
875
|
+
# make sure recovery is complete before committing
|
|
876
|
+
with torch.profiler.record_function(
|
|
877
|
+
"torchft::manager::should_commmit::recovery_stream::synchronize"
|
|
878
|
+
):
|
|
879
|
+
if self._recovery_event is not None:
|
|
880
|
+
self._recovery_event.synchronize()
|
|
881
|
+
self._recovery_event = None
|
|
882
|
+
|
|
883
|
+
with torch.profiler.record_function(
|
|
884
|
+
"torchft::manager::should_commit::current_stream::synchronize"
|
|
885
|
+
):
|
|
886
|
+
if torch.accelerator.is_available():
|
|
887
|
+
synchronize()
|
|
888
|
+
|
|
889
|
+
if err := self._pg.errored():
|
|
890
|
+
self.report_error(err)
|
|
891
|
+
|
|
892
|
+
# apply state_dict if healing
|
|
893
|
+
if self._healing:
|
|
894
|
+
self._apply_pending_state_dict()
|
|
895
|
+
|
|
896
|
+
enough_replicas = self.num_participants() >= self._min_replica_size
|
|
897
|
+
local_should_commit = enough_replicas and self._errored is None
|
|
898
|
+
should_commit = self._client.should_commit(
|
|
899
|
+
self._group_rank,
|
|
900
|
+
self._step,
|
|
901
|
+
local_should_commit,
|
|
902
|
+
timeout=timeout or self._timeout,
|
|
903
|
+
)
|
|
904
|
+
self._logger.info(
|
|
905
|
+
f"should_commit={should_commit} enough_replicas={enough_replicas}, errored={self._errored}"
|
|
906
|
+
)
|
|
907
|
+
|
|
908
|
+
self.commits_logger.info(
|
|
909
|
+
"",
|
|
910
|
+
extra={
|
|
911
|
+
"job_id": os.environ.get("JOB_ID", "unknown"),
|
|
912
|
+
"replica_id": self._replica_id,
|
|
913
|
+
"rank": self._group_rank,
|
|
914
|
+
"quorum_id": self._quorum_id,
|
|
915
|
+
"step": self._step,
|
|
916
|
+
"commit_result": should_commit,
|
|
917
|
+
},
|
|
918
|
+
)
|
|
919
|
+
|
|
920
|
+
self._checkpoint_transport.disallow_checkpoint()
|
|
921
|
+
|
|
922
|
+
# decide whether we're in a healthy state to increase the step count
|
|
923
|
+
if should_commit:
|
|
924
|
+
self._step += 1
|
|
925
|
+
self._batches_committed += self.num_participants()
|
|
926
|
+
self._commit_failures = 0 # Reset failure counter on success
|
|
927
|
+
else:
|
|
928
|
+
self._commit_failures += 1
|
|
929
|
+
# Check if we've hit max retries
|
|
930
|
+
if (
|
|
931
|
+
self._max_retries is not None
|
|
932
|
+
and self._commit_failures > self._max_retries
|
|
933
|
+
):
|
|
934
|
+
msg = f"should_commit failed {self._commit_failures} times consecutively, exceeding max_retries={self._max_retries}"
|
|
935
|
+
self._logger.exception(msg)
|
|
936
|
+
raise RuntimeError(msg)
|
|
937
|
+
|
|
938
|
+
return should_commit
|
|
939
|
+
|
|
940
|
+
def load_state_dict(self, state_dict: Dict[str, int]) -> None:
    """
    Restore the manager's counters from a previously saved state dict.

    Args:
        state_dict: mapping with the keys ``"step"`` and
            ``"batches_committed"``, as produced by :meth:`state_dict`

    Raises:
        KeyError: if either key is missing
    """
    restored_step = state_dict["step"]
    self._step = restored_step

    restored_batches = state_dict["batches_committed"]
    self._batches_committed = restored_batches
|
|
951
|
+
|
|
952
|
+
def _manager_state_dict(self) -> Dict[str, object]:
|
|
953
|
+
with self._state_dict_lock.r_lock():
|
|
954
|
+
assert (
|
|
955
|
+
len(self._user_state_dicts) > 0
|
|
956
|
+
), "user state_dict is not initialized."
|
|
957
|
+
return {
|
|
958
|
+
"user": {key: value() for key, value in self._user_state_dicts.items()},
|
|
959
|
+
"torchft": self.state_dict(),
|
|
960
|
+
}
|
|
961
|
+
|
|
962
|
+
def state_dict(self) -> Dict[str, int]:
    """
    Snapshot the manager's counters for checkpointing.

    The result can be passed to :meth:`load_state_dict` to restore the
    manager from a previous checkpoint.

    Returns:
        a dict with the current ``"step"`` and ``"batches_committed"`` values
    """
    return dict(step=self._step, batches_committed=self._batches_committed)
|
|
973
|
+
|
|
974
|
+
def current_step(self) -> int:
    """
    Report the manager's step counter.

    The counter is advanced by ``should_commit`` when a step commits and is
    overwritten by ``load_state_dict``.

    Returns:
        the current step count
    """
    step = self._step
    return step
|
|
984
|
+
|
|
985
|
+
def batches_committed(self) -> int:
    """
    Report the total number of batches committed across all steps and replicas.

    For example, 5 replicas participating in 2 steps is 10 batches, which may
    be more than 10 examples depending on batch size.

    The counter grows by the number of participants each time
    ``should_commit`` succeeds and is overwritten by ``load_state_dict``.

    Returns:
        the total number of batches committed
    """
    total = self._batches_committed
    return total
|
|
997
|
+
|
|
998
|
+
def participating_rank(self) -> Optional[int]:
    """
    Get the replica group rank of the current quorum. This will be the same on
    all ranks within the replica group.

    Blocks on the async quorum if it is not yet ready.

    Returns:
        the rank within the current quorum, or None when no quorum has been
        started or this replica group is not participating
    """
    if self._quorum_future is not None:
        self.wait_quorum()
        return self._participating_replica_rank

    return None
|
|
1016
|
+
|
|
1017
|
+
def num_participants(self) -> int:
    """
    Get the number of replicas participating in the current step.

    Blocks on the async quorum if it is not yet ready.

    Returns:
        the number of participants in the current quorum, or 0 when no quorum
        has been started

    Raises:
        AssertionError: if the recorded world size is negative
    """
    if self._quorum_future is not None:
        self.wait_quorum()

        world_size = self._participating_replica_world_size
        assert world_size >= 0, "internal error"
        return world_size

    return 0
|
|
1035
|
+
|
|
1036
|
+
def is_participating(self) -> bool:
    """
    Tell whether this replica takes part in the current quorum.

    A replica without a participating rank, or one that is currently
    healing, does not participate.

    Returns:
        whether this replica is participating in the current quorum

    Raises:
        AssertionError: if the replica is healing while async quorum is disabled
    """
    if self._participating_replica_rank is not None:
        if not self._healing:
            return True
        # Reaching the healing state here is only expected with async quorum.
        assert self._use_async_quorum

    return False
|
|
1049
|
+
|
|
1050
|
+
|
|
1051
|
+
class _ManagerLogger:
|
|
1052
|
+
def __init__(self, manager: Manager, replica_id: str, group_rank: int) -> None:
|
|
1053
|
+
self._logger: logging.Logger = logging.getLogger(__name__)
|
|
1054
|
+
self._replica_id = replica_id
|
|
1055
|
+
self._group_rank = group_rank
|
|
1056
|
+
self._manager = manager
|
|
1057
|
+
|
|
1058
|
+
def prefix(self) -> str:
|
|
1059
|
+
return f"[{self._replica_id}/{self._group_rank} - step {self._manager.current_step()}]"
|
|
1060
|
+
|
|
1061
|
+
def info(self, msg: str) -> None:
|
|
1062
|
+
self._logger.info(f"{self.prefix()} {msg}")
|
|
1063
|
+
|
|
1064
|
+
def warn(self, msg: str) -> None:
|
|
1065
|
+
self._logger.warn(f"{self.prefix()} {msg}")
|
|
1066
|
+
|
|
1067
|
+
def exception(self, msg: str) -> None:
|
|
1068
|
+
self._logger.exception(f"{self.prefix()} {msg}")
|
|
1069
|
+
|
|
1070
|
+
|
|
1071
|
+
# Generic placeholders for the future classes in this module: T is the value
# type of the future handed to a callback, S the value type of the future
# returned by `then()`.
T = TypeVar("T")
S = TypeVar("S")
|
|
1073
|
+
|
|
1074
|
+
|
|
1075
|
+
class _SimpleFuture(torch.futures.Future[T]):
    """
    A minimal torch.futures.Future stand-in that just carries a known value.

    Instances wrap a pre-determined value so it can be handed to the callbacks
    in the callback chain of `_ManagedFuture`. Only `value()` is usable; every
    other future operation raises `NotImplementedError`.

    This class is designed to be used only in specific contexts where we don't
    want to call `value()` on the underlying `Future` as that would cause the
    CPU to block.
    """

    # Message shared by every operation this future deliberately does not support.
    _CHAIN_ONLY_MSG = "This future is only supposed to be used in callback chain to extract the value"

    def __init__(self, value: object) -> None:
        super().__init__()
        self._value = value

    def value(self) -> object:
        """Return the wrapped value without blocking."""
        return self._value

    def then(
        self, callback: Callable[[torch.futures.Future[T]], S]
    ) -> torch.futures.Future[S]:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)

    def wait(self) -> object:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)

    def done(self) -> bool:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)

    def add_done_callback(
        self, callback: Callable[[torch.futures.Future[T]], None]
    ) -> None:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)

    def set_result(self, result: object) -> None:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)

    def set_exception(self, result: object) -> None:
        raise NotImplementedError(self._CHAIN_ONLY_MSG)
|
|
1127
|
+
|
|
1128
|
+
|
|
1129
|
+
class _ManagedFuture(torch.futures.Future[T]):
    """
    A lazy Future used together with `_ManagedWork`.

    `then()` does not run or schedule anything: it only stores the callback and
    creates the next `_ManagedFuture` in the chain. The chain is a linked list
    connected through `_next`, and appending to it also moves the tail pointer
    kept on the owning `_ManagedWork`. The stored callbacks are attached to
    real futures later, when `_ManagedWork` sets up the chain.

    `wait()` delegates to the underlying torch.futures.Future in `_fut`; the
    remaining future operations raise `NotImplementedError`.
    """

    # Message shared by every operation this future deliberately does not support.
    _CHAIN_BUILDER_MSG = "This future is supposed to be used to create callback chain"

    def __init__(self, managed_work: weakref.ReferenceType["_ManagedWork"]) -> None:
        super().__init__()
        # Store a weak reference to _ManagedWork to avoid reference cycles
        self._managed_work = managed_work

        # The underlying torch.futures.Future that this class delegates to
        self._fut: Optional[torch.futures.Future[T]] = None

        # The next future in the callback chain
        self._next: Optional[_ManagedFuture[object]] = None

        # The callback to be executed when the future is completed - this callback
        # returns the next future in the chain
        self._callback: Optional[Callable[[torch.futures.Future[T]], object]] = None

    def then(
        self,
        callback: Callable[[torch.futures.Future[T]], S],
    ) -> torch.futures.Future[S]:
        """
        Record `callback` for later execution and extend the chain.

        A new `_ManagedFuture` is appended after this one, registered as the
        tail of the owning `_ManagedWork`, and returned.
        """
        owner = self._managed_work()
        assert owner is not None, "got garbage collected"

        self._callback = callback
        successor = _ManagedFuture[object](self._managed_work)
        self._next = successor
        owner._managed_fut_tail = successor
        return cast(torch.futures.Future[S], successor)

    def wait(self) -> object:
        """Wait on the underlying future, which must have been attached already."""
        underlying = self._fut
        assert underlying
        return underlying.wait()

    def value(self) -> object:
        raise NotImplementedError(self._CHAIN_BUILDER_MSG)

    def done(self) -> bool:
        raise NotImplementedError(self._CHAIN_BUILDER_MSG)

    def add_done_callback(
        self, callback: Callable[[torch.futures.Future[T]], None]
    ) -> None:
        raise NotImplementedError(self._CHAIN_BUILDER_MSG)

    def set_result(self, result: object) -> None:
        raise NotImplementedError(self._CHAIN_BUILDER_MSG)

    def set_exception(self, result: object) -> None:
        raise NotImplementedError(self._CHAIN_BUILDER_MSG)
|
|
1209
|
+
|
|
1210
|
+
|
|
1211
|
+
class _ManagedWork(dist._Work):
    """
    A specialized `Work` implementation that works alongside `_ManagedFuture` to create
    callback chains lazily. The callback chain is created when `wait()`, `block_current_stream()`
    or `synchronize()` are called.
    """

    def __init__(
        self,
        manager: Manager,
        work: dist._Work,
        value: object,
    ) -> None:
        """
        Args:
            manager: manager used for error reporting and future wrapping
            work: the underlying work object to delegate to
            value: initial value passed to the first callback in the chain
        """
        super().__init__()
        # Underlying `Work` returned from process group operations
        self._work = work

        # Used to report errors to the manager through `wrap_future()`
        self._manager = manager

        # The value returned by the final future in the callback chain
        self._value = value

        # The head of the callback chain
        self._managed_fut_head = _ManagedFuture[object](weakref.ref(self))

        # The tail of the callback chain
        self._managed_fut_tail: _ManagedFuture[object] = self._managed_fut_head

        # The stream used to create the `Work` - we ensure all operations in the future
        # callback chain are executed on this stream
        self._stream: Optional[torch.Stream] = (
            torch.accelerator.current_stream()
            if torch.accelerator.is_available()
            else None
        )

        # To ensure the future callback chain is only created once
        self._is_set_future_callback_called = False

    def _set_future_callback(
        self,
    ) -> None:
        """
        Sets up the stored future callback chain.

        This method creates a chain of callbacks for the futures in the managed work,
        ensuring that each callback is executed in the proper order and with the
        appropriate stream context. It also wraps the first chained future with error
        handling through the manager's `wrap_future` method.

        The method is called internally when waiting or synchronizing on the work.
        Calls after the first one are no-ops.
        """
        if self._is_set_future_callback_called:
            return

        managed_fut: _ManagedFuture[object] = self._managed_fut_head
        managed_fut._fut = self._work.get_future()
        value = self._value

        is_future_wrapped = False
        while managed_fut._next:

            def callback(
                fut: torch.futures.Future[object],
                # Bind the chain node for THIS iteration. Reading the loop
                # variable through the closure would pick up whatever node the
                # loop has advanced to by the time the callback actually runs.
                current: _ManagedFuture[object] = managed_fut,
            ) -> object:
                # `value` is intentionally shared: each callback consumes the
                # previous callback's result and publishes its own.
                nonlocal value
                # change the stream to avoid making the callback stream
                # dependent on process group stream running the allreduce
                with get_stream_context(self._stream):
                    # Setup stream dependency
                    fut.wait()
                    assert current._callback
                    value = current._callback(
                        _SimpleFuture(value),
                    )
                    return value

            assert managed_fut._fut
            fut = managed_fut._fut.then(callback)
            assert managed_fut._next
            managed_fut = managed_fut._next
            managed_fut._fut = fut

            # Only the first chained future is wrapped by the manager.
            if is_future_wrapped:
                continue

            managed_fut._fut = self._manager.wrap_future(managed_fut._fut, value)
            is_future_wrapped = True

        self._value = value
        self._is_set_future_callback_called = True

    def _assert_same_stream(self) -> None:
        """
        Asserts that the current stream is the same as the one used to create this work.

        This makes sure users of the API are aware about stream dependencies.
        Nothing is checked when no stream was captured at construction time.
        """
        if self._stream is not None:
            assert self._stream == torch.accelerator.current_stream()

    def wait(self, timeout: Optional[timedelta] = None) -> bool:
        """
        Wait for the underlying work and for the whole callback chain.

        Args:
            timeout: accepted for interface compatibility; it is not forwarded
                to the underlying work by this implementation

        Returns:
            True on success; False if any exception was raised, in which case
            the exception is logged and reported to the manager
        """
        self._assert_same_stream()

        try:
            with get_stream_context(self._stream):
                self._work.wait()
                self._set_future_callback()

            with get_stream_context(self._stream):
                self._managed_fut_tail.wait()

            return True
        except Exception as e:
            self._manager._logger.exception(f"got exception waiting for work {e}")
            self._manager.report_error(e)
            return False

    def block_current_stream(self, timeout: Optional[timedelta] = None) -> None:
        """
        Block the captured stream on the underlying work, then set up the
        callback chain.

        Args:
            timeout: accepted for interface compatibility; it is not forwarded
                to the underlying work by this implementation
        """
        self._assert_same_stream()

        with get_stream_context(self._stream):
            self._work.block_current_stream()

        self._set_future_callback()

    def synchronize(self) -> None:
        """
        Set up the callback chain, blocking the current stream first when CUDA
        is available.
        """
        self._assert_same_stream()

        if torch.cuda.is_available():
            self.block_current_stream()
        elif torch.xpu.is_available():
            self._set_future_callback()
        else:
            # No stream dependencies need to be set
            self._set_future_callback()

    def get_future(
        self,
    ) -> torch.futures.Future[object]:
        """
        Returns:
            The tail of the managed future chain, which represents the final
            result of all the chained operations. This future will be completed when
            all the work and its callbacks have been executed.
        """
        return self._managed_fut_tail
|