torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager_integ_test.py
@@ -0,0 +1,653 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
import threading
import time
import traceback
from collections import defaultdict
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import contextmanager, ExitStack
from dataclasses import dataclass, field
from datetime import timedelta
from enum import auto, Enum
from typing import (
    Any,
    cast,
    Dict,
    Generator,
    List,
    Optional,
    Protocol,
    Set,
    Tuple,
    TypeVar,
)
from unittest import TestCase

import torch
import torch.distributed as dist
from parameterized import parameterized
from torch import nn, optim
from torch._dynamo.utils import timed

from torchft._torchft import LighthouseServer
from torchft.ddp import DistributedDataParallel
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import (
    FakeProcessGroupWrapper,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)

INIT_LOCK: threading.Lock = threading.Lock()


class MyModel(nn.Module):
    def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
        super().__init__()
        self.in_dim = in_dim
        self.out_dim = out_dim
        self.layers = nn.ModuleList(
            [
                nn.Linear(in_dim, 8),
                nn.ReLU(),
                nn.Linear(8, out_dim),
                nn.ReLU(),
            ]
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        for layer in self.layers:
            x = layer(x)
        return x

    def get_rand_inputs(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        return torch.rand(batch_size, self.in_dim, device=device)

    def get_rand_labels(
        self, batch_size: int, device: torch.device = torch.device("cpu")
    ) -> torch.Tensor:
        return torch.randint(3, (batch_size,), device=device)


class InjectedFailure(Exception):
    pass


class EventInjectorEvent(Enum):
    # Crashes a rank
    Failure = auto()
    # Used to wait for a rank to reach a certain step before continuing.
    # Users need to make sure the size of the barrier is appropriately set.
    Barrier = auto()
    # Fails the allreduce call made by a rank
    AllreduceFailure = auto()


class EventInjectorInfo:
    def __init__(self, event: EventInjectorEvent, data: object) -> None:
        self.event = event
        self.data = data


class EventInjector:
    def __init__(self) -> None:
        self._pg: Optional[FakeProcessGroupWrapper] = None
        self._lock = threading.Lock()
        self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
        self.count: dict[EventInjectorEvent, int] = defaultdict(int)

    def set_pg(self, pg: FakeProcessGroupWrapper) -> None:
        with self._lock:
            self._pg = pg

    def fail_at(self, rank: int, step: int) -> "EventInjector":
        with self._lock:
            assert (rank, step) not in self._events
            self._events[(rank, step)] = EventInjectorInfo(
                EventInjectorEvent.Failure, None
            )
            return self

    def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
        with self._lock:
            assert (rank, step) not in self._events
            self._events[(rank, step)] = EventInjectorInfo(
                EventInjectorEvent.AllreduceFailure, None
            )
            return self

    def barrier_at(
        self, rank: int, step: int, barrier: threading.Barrier
    ) -> "EventInjector":
        with self._lock:
            assert (rank, step) not in self._events
            self._events[(rank, step)] = EventInjectorInfo(
                EventInjectorEvent.Barrier, barrier
            )
            return self

    def check(self, rank: int, step: int) -> None:
        with self._lock:
            key = (rank, step)
            if key in self._events:
                event_info = self._events.pop(key)

                self.count[event_info.event] += 1

                if event_info.event == EventInjectorEvent.Failure:
                    print(f"injecting failure {rank=} {step=}")
                    raise InjectedFailure(f"injected failure {rank=} {step=}")

                if event_info.event == EventInjectorEvent.AllreduceFailure:
                    print(f"injecting allreduce failure {rank=} {step=}")
                    assert self._pg is not None
                    self._pg.report_future_error(
                        RuntimeError("injected allreduce error")
                    )
                    return

                if event_info.event == EventInjectorEvent.Barrier:
                    print(f"waiting for barrier {rank=} {step=}")
                    cast(threading.Barrier, event_info.data).wait()
                    return

                raise RuntimeError(f"unknown event {event_info.event}")


# R for an arbitrary return type
R = TypeVar("R", covariant=True)


class TrainLoop(Protocol[R]):
    def __call__(
        self,
        rank: int,
        store_port: int,
        device: torch.device,
        runner: "Runner",
        train_loop_args: dict[str, Any] = field(default_factory=dict),
    ) -> R: ...


@dataclass
class Runner:
    replica_id: int
    num_replicas: int
    lighthouse_address: str
    event_injector: EventInjector
    train_loop: TrainLoop[object]

    use_cuda: bool = False
    world_size: int = 1
    attempts: int = 3
    manager_args: Dict[str, object] = field(default_factory=dict)
    train_loop_args: Dict[str, Any] = field(default_factory=dict)

    def _replica_main(self) -> List[object]:
        store = dist.TCPStore(
            host_name="localhost",
            port=0,
            is_master=True,
            wait_for_workers=False,
        )

        with ThreadPoolExecutor(
            max_workers=self.world_size, thread_name_prefix=f"replica{self.replica_id}"
        ) as executor:
            futures = []
            for rank in range(self.world_size):
                if self.use_cuda:
                    num_cuda_devices = torch.cuda.device_count()
                    assert num_cuda_devices >= self.num_replicas
                    device_index = (
                        num_cuda_devices // self.num_replicas
                    ) * self.replica_id + rank
                    device = torch.device(f"cuda:{device_index}")
                else:
                    device = torch.device("cpu")

                futures.append(
                    executor.submit(
                        self.train_loop,
                        rank=rank,
                        store_port=store.port,
                        device=device,
                        runner=self,
                        train_loop_args=self.train_loop_args,
                    )
                )

            for fut in as_completed(futures):
                try:
                    fut.result()
                except Exception as e:
                    logger.exception(f"worker {self.replica_id=} threw exception: {e}")
                    raise

            return [fut.result() for fut in futures]

    def run_replica(self) -> List[object]:
        for i in range(self.attempts):
            try:
                print(
                    f"starting replica group {self.replica_id=} {self.world_size=} attempt {i}"
                )
                return self._replica_main()
            except InjectedFailure as e:
                print("got injected failure", i, e)
                if i == self.attempts - 1:
                    raise
                continue

        raise RuntimeError("ran out of attempts")


def ddp_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
    train_loop_args: dict[str, Any] = {},
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:

        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            optimizer.load_state_dict(state_dict["optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:
            return {
                "model": m.state_dict(),
                "optim": optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(lambda: manager.shutdown(wait=False))

        with INIT_LOCK:
            # We need to lock during init for testing init_sync=False as all
            # threads share the same RNG
            torch.manual_seed(42)
            m: nn.Module = MyModel()

        m: nn.Module = DistributedDataParallel(manager, m)
        optimizer: optim.Optimizer = OptimizerWrapper(
            manager, optim.Adam(m.parameters())
        )
        criterion = nn.CrossEntropyLoss()

        while True:
            inputs = torch.rand(2, 3)
            labels = torch.randint(4, (2,))

            optimizer.zero_grad()
            out = m(inputs)
            loss = criterion(out, labels)

            loss.backward()

            optimizer.step()

            if manager.current_step() >= 4:
                break

            runner.event_injector.check(rank, manager.current_step())

        # return state_dict so we can check consistency
        return state_dict()


class ManagerIntegTest(TestCase):
    @contextmanager
    def assertElapsedLessThan(
        self, timeout: float, msg: str = ""
    ) -> Generator[None, None, None]:
        start = time.perf_counter()
        yield
        elapsed = time.perf_counter() - start
        self.assertLess(elapsed, timeout, msg)

    def test_ddp_healthy(self) -> None:
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                event_injector = EventInjector()
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=ddp_train_loop,
                )
                futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            state_dicts.append(fut.result())

        lighthouse.shutdown()

        for state_dict in state_dicts:
            torch.testing.assert_close(state_dict, state_dicts[0])

    @parameterized.expand(
        [
            (
                "async_quorum",
                True,
            ),
            (
                "sync_quorum",
                False,
            ),
        ]
    )
    def test_ddp_recovery(
        self,
        name: str,
        use_async_quorum: bool,
    ) -> None:
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        event_injectors = [
            EventInjector(),
            EventInjector().fail_at(0, 2),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    manager_args={
                        "use_async_quorum": use_async_quorum,
                    },
                    train_loop=ddp_train_loop,
                )
                futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            try:
                state_dicts.append(fut.result())
            except Exception as e:
                print(e)
                raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            torch.testing.assert_close(state_dict, state_dicts[0])

        self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)

    def test_ddp_skip_init_sync(
        self,
    ) -> None:
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        # no failures
        event_injectors = [
            EventInjector(),
            EventInjector(),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    manager_args={
                        "use_async_quorum": False,
                        "init_sync": False,
                    },
                    train_loop=ddp_train_loop,
                )
                futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            try:
                state_dicts.append(fut.result())
            except Exception as e:
                print(e)
                raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            torch.testing.assert_close(state_dict, state_dicts[0])

    def test_ddp_recovery_multi_rank(self) -> None:
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        world_size = 2
        futures = []

        event_injectors = [
            EventInjector(),
            EventInjector().fail_at(0, 2).fail_at(1, 2),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    world_size=world_size,
                    train_loop=ddp_train_loop,
                )
                futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            try:
                state_dicts.append(fut.result())
            except Exception as e:
                print(e)
                raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            torch.testing.assert_close(state_dict, state_dicts[0])

    def test_quorum_timeout(self) -> None:
        with ExitStack() as stack:
            lighthouse = LighthouseServer(
                bind="[::]:0",
                min_replicas=2,
            )
            stack.callback(lighthouse.shutdown)

            store = dist.TCPStore(
                host_name="localhost",
                port=0,
                is_master=True,
                wait_for_workers=False,
            )

            pg = ProcessGroupGloo()
            manager = Manager(
                pg=pg,
                min_replica_size=2,
                load_state_dict=lambda x: None,
                state_dict=lambda: None,
                store_addr="localhost",
                store_port=store.port,
                rank=0,
                world_size=2,
                lighthouse_addr=lighthouse.address(),
                port=19530,
                use_async_quorum=False,
            )
            stack.callback(lambda: manager.shutdown(wait=False))

            with self.assertElapsedLessThan(1.0):
                with self.assertRaisesRegex(
                    TimeoutError,
                    "status: Cancelled, message.*Timeout expired",
                ):
                    manager.start_quorum(timeout=timedelta(seconds=0.01))

            with self.assertElapsedLessThan(1.0):
                with self.assertRaisesRegex(
                    TimeoutError,
                    "status: Cancelled, message.*Timeout expired",
                ):
                    manager.should_commit(timeout=timedelta(seconds=0.01))

    @parameterized.expand(
        [
            (True,),  # Test with CUDA
            (False,),  # Test without CUDA (CPU)
        ]
    )
    def test_manager_allreduce(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")

        # manager supports allreduce but we found an issue where the future callback is getting called
        # before the allreduce is complete. This test is to ensure that the callback has stream synchronization
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                event_injector = EventInjector()
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=all_reduce_callback,
                    use_cuda=use_cuda,
                )
                futures.append(executor.submit(runner.run_replica))

        results = []
        for fut in as_completed(futures):
            try:
                results.append(fut.result()[0])
            except Exception as e:
                print(e, flush=True)
                traceback.print_exc()
                raise

        lighthouse.shutdown()

        print(results)
        r0, r1 = results
        torch.testing.assert_close(r0, r1, check_device=False)


def all_reduce_callback(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
    train_loop_args: dict[str, Any] = {},
) -> Optional[torch.Tensor]:
    with ExitStack() as stack:
        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        if device.type == "cuda":
            pg = ProcessGroupBabyNCCL()
        else:
            pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            use_async_quorum=False,
            load_state_dict=lambda x: None,
            state_dict=lambda: None,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            timeout=timedelta(seconds=10),
            quorum_timeout=timedelta(seconds=10),
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(lambda: manager.shutdown(wait=False))

        manager.start_quorum()
        t1 = torch.ones((1, 3), device=device)
        work = manager.allreduce(t1)
        work.wait()
        return t1
    return None


if __name__ == "__main__":
    import unittest

    unittest.main()