torchft-nightly 2026.1.3 (cp310-cp310-manylinux_2_24_x86_64.whl)
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd_integ_test.py
@@ -0,0 +1,600 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import copy
import logging
import os
import re
import sys
import threading
import traceback
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import ExitStack
from dataclasses import field
from datetime import timedelta
from typing import Any, cast, Dict
from unittest import skipIf, TestCase

import torch
from parameterized import parameterized
from torch import nn, optim
from torch.distributed.pipelining import pipeline, SplitPoint
from torch.distributed.tensor import DTensor, Replicate

from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel
from torchft._torchft import LighthouseServer
from torchft.local_sgd import DiLoCo, LocalSGD
from torchft.manager import Manager
from torchft.manager_integ_test import (
    EventInjector,
    EventInjectorEvent,
    MyModel,
    Runner,
)
from torchft.process_group import (
    FakeProcessGroupWrapper,
    ProcessGroupBabyNCCL,
    ProcessGroupGloo,
)

logger: logging.Logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)


def local_sgd_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
    train_loop_args: dict[str, Any] = {},
) -> Dict[str, Dict[str, object]]:
    with ExitStack() as stack:

        def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
            m.load_state_dict(state_dict["model"])
            optimizer.load_state_dict(state_dict["optim"])

        def state_dict() -> Dict[str, Dict[str, object]]:
            return {
                "model": m.state_dict(),
                "optim": optimizer.state_dict(),
            }

        print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")

        if device.type == "cuda":
            pg = ProcessGroupBabyNCCL()
        else:
            pg = ProcessGroupGloo()
        manager = Manager(
            pg=pg,
            min_replica_size=2,
            load_state_dict=load_state_dict,
            state_dict=state_dict,
            replica_id=str(runner.replica_id),
            store_addr="localhost",
            store_port=store_port,
            rank=rank,
            world_size=runner.world_size,
            lighthouse_addr=runner.lighthouse_address,
            port=19530 + runner.replica_id,
            timeout=timedelta(seconds=10),
            # pyre-fixme[6]: Incompatible parameter type
            **runner.manager_args,
        )
        stack.callback(lambda: manager.shutdown(wait=False))

        m: nn.Module = MyModel().to(device)

        optimizer: optim.Optimizer = optim.Adam(m.parameters())
        criterion = nn.CrossEntropyLoss()

        with LocalSGD(manager, m, optimizer, sync_every=2) as local_sgd:
            while True:
                inputs = torch.rand(2, 3).to(device)
                labels = torch.randint(4, (2,)).to(device)

                optimizer.zero_grad()
                out = m(inputs)
                loss = criterion(out, labels)
                loss.backward()

                optimizer.step()

                if manager.current_step() >= 4:
                    break

                runner.event_injector.check(rank, manager.current_step())

        # return state_dict so we can check consistency
        return state_dict()
    return {}


def diloco_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
    train_loop_args: dict[str, Any] = {},
) -> Dict[str, Dict[str, object]]:
    model_state_dict = train_loop_args.get("model_state_dict", {})
    n_fragments = train_loop_args.get("n_fragments", 1)
    diloco_args = train_loop_args.get("diloco_args", {})

    with ExitStack() as stack:
        trainer = DiLoCoTrainer(
            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
        )
        stack.callback(trainer.manager.shutdown)
        return trainer.train_loop()
    return {}


def assert_equal_global_state(
    n_fragments: int,
    rep0: dict[str, dict[str, dict[str, dict[str, object]]]],
    rep1: dict[str, dict[str, dict[str, dict[str, object]]]],
) -> None:
    """
    Asserts that the global state of the two replicas are equal
    """
    for step in rep0.keys():
        for i in range(n_fragments):
            torch.testing.assert_close(
                rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
                    "original_parameters"
                ],
                rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
                    "original_parameters"
                ],
                check_device=False,
                msg=f"{step=} {i=}",
            )
            # Check all outer optimizers
            torch.testing.assert_close(
                cast(
                    dict[str, dict[str, torch.Tensor]],
                    rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
                        "outer_optimizer"
                    ],
                ),
                cast(
                    dict[str, dict[str, torch.Tensor]],
                    rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
                        "outer_optimizer"
                    ],
                ),
                check_device=False,
            )


class LocalSGDIntegTest(TestCase):
    # TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
    # Because of that the test is disabled for cuda
    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_local_sgd_recovery(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        event_injectors = [
            EventInjector(),
            EventInjector().fail_at(0, 2),
        ]

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=local_sgd_train_loop,
                    use_cuda=use_cuda,
                    manager_args={
                        "use_async_quorum": False,
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result())
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        for state_dict in state_dicts:
            # LocalSGD only guarantees that the model is consistent across
            # replicas but uses separate optimizer states.
            torch.testing.assert_close(
                state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
            )

        self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)

    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_diloco_healthy(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
        num_replicas = 2
        futures = []

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MultiMyModel(2, 3, 1)

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                event_injector = EventInjector()
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=diloco_train_loop,
                    use_cuda=use_cuda,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []
            for fut in as_completed(futures):
                try:
                    state_dicts.append(fut.result()[0])
                except Exception as e:
                    print(e, flush=True)
                    traceback.print_exc()
                    raise

        lighthouse.shutdown()

        rep0, rep1 = state_dicts
        assert_equal_global_state(1, rep1, rep0)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipIf(sys.platform == "darwin", "not reliable on mac")
    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_diloco_recovery(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        event_injectors = [
            EventInjector(),
            EventInjector().fail_at(0, 2),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MultiMyModel(2, 3, 1)

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                continue

            for fut in futures:
                try:
                    state_dicts.append(fut.result()[0])
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        rep0, rep1 = state_dicts

        # Inner optimizer and local model parameters will be different e.g.
        # with 2 replicas r1 and r2, we sync every 2 steps
        #
        # - Manager Step 1
        #   - Step 1: r1 and r2 step
        #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
        # - Manager Step 2
        #   - Step 1: r1 steps but r2 fails
        #   - Step 2:
        #     - r1 steps, sync fails because r2 is down
        #     - r1 recovers r2 from the model state at this step
        #       that is different from the model for r1 at the beginning
        #       of step Manager Step 2
        #
        # Outer optimizer and global model should be the same
        assert_equal_global_state(1, rep1, rep0)

        self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipIf(sys.platform == "darwin", "not reliable on mac")
    @parameterized.expand(
        [
            # (True,),
            (False,),
        ]
    )
    def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []

        event_injectors = [
            EventInjector(),
            EventInjector().fail_at(0, 2),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MultiMyModel(2, 3, 2)

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=diloco_train_loop,
                    train_loop_args={
                        "model_state_dict": m.state_dict(),
                        "n_fragments": 2,
                        "diloco_args": {
                            "fragment_sync_delay": 1,
                            "sync_every": 4,
                        },
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            state_dicts = []

            for fut in as_completed(futures):
                continue

            for fut in futures:
                try:
                    state_dicts.append(fut.result()[0])
                except Exception as e:
                    print(e)
                    raise

        lighthouse.shutdown()

        rep0, rep1 = state_dicts

        assert_equal_global_state(2, rep1, rep0)

        self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)

    CONFIG: list[tuple[bool, int, int, float]] = [
        (use_cuda, n_fragments, fragment_sync_delay, alpha)
        for use_cuda in [False]
        for n_fragments in [1, 2]
        for fragment_sync_delay in [0, 1]
        for alpha in [0.0, 0.5, 1.0]
    ]

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipIf(sys.platform == "darwin", "not reliable on mac")
    @parameterized.expand(CONFIG)
    def test_streaming_diloco_upscale(
        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
    ) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 3
        futures = []
        executors = []

        barrier = threading.Barrier(num_replicas)

        event_injectors = [
            # Make this replica join after other replicas have made 2 steps
            EventInjector().barrier_at(0, 0, barrier),
            EventInjector().barrier_at(0, 2, barrier),
            EventInjector().barrier_at(0, 2, barrier),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MultiMyModel(2, 3, n_fragments)

        for replica_id, event_injector in zip(range(num_replicas), event_injectors):
            executor = ThreadPoolExecutor(max_workers=1)
            executors.append(executor)
            runner = Runner(
                replica_id=replica_id,
                num_replicas=num_replicas,
                lighthouse_address=lighthouse.address(),
                event_injector=event_injector,
                train_loop=diloco_train_loop,
                train_loop_args={
                    "model_state_dict": m.state_dict(),
                    "n_fragments": n_fragments,
                    "diloco_args": {
                        "fragment_sync_delay": fragment_sync_delay,
                        "sync_every": 4,
                        "fragment_update_alpha": alpha,
                    },
                },
            )
            futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            continue

        for fut in futures:
            try:
                state_dicts.append(fut.result()[0])
            except Exception as e:
                print(e)
                raise

        lighthouse.shutdown()

        rep0, rep1, rep2 = state_dicts

        assert_equal_global_state(n_fragments, rep0, rep1)
        assert_equal_global_state(n_fragments, rep0, rep2)

        for event_injector in event_injectors:
            self.assertEqual(event_injectors[1].count[EventInjectorEvent.Barrier], 1)

    # pyre-fixme[56]: Pyre was not able to infer the type of argument
    @skipIf(sys.platform == "darwin", "not reliable on mac")
    @parameterized.expand(CONFIG)
    def test_streaming_diloco_commit_failure(
        self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
    ) -> None:
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=2,
        )
        num_replicas = 2
        futures = []
        executors = []

        event_injectors = [
            EventInjector().fail_allreduce_at(0, 1),
            EventInjector().fail_allreduce_at(0, 1),
        ]

        torch.manual_seed(42)
        # Initialize the model so we can pass in the state_dict
        m: nn.Module = MultiMyModel(2, 3, n_fragments)

        for replica_id, event_injector in zip(range(num_replicas), event_injectors):
            executor = ThreadPoolExecutor(max_workers=1)
            executors.append(executor)
            runner = Runner(
                replica_id=replica_id,
                num_replicas=num_replicas,
                lighthouse_address=lighthouse.address(),
                event_injector=event_injector,
                train_loop=diloco_train_loop,
                train_loop_args={
                    "model_state_dict": m.state_dict(),
                    "n_fragments": n_fragments,
                    "diloco_args": {
                        "fragment_sync_delay": fragment_sync_delay,
                        "sync_every": 4,
                        "fragment_update_alpha": alpha,
                    },
                },
            )
            futures.append(executor.submit(runner.run_replica))

        state_dicts = []

        for fut in as_completed(futures):
            continue

        for fut in futures:
            try:
                state_dicts.append(fut.result()[0])
            except Exception as e:
                print(e)
                raise

        lighthouse.shutdown()

        rep0, rep1 = state_dicts

        assert_equal_global_state(n_fragments, rep0, rep1)

        for event_injector in event_injectors:
            self.assertEqual(
                event_injector.count[EventInjectorEvent.AllreduceFailure], 1
            )


if __name__ == "__main__":
    import unittest

    unittest.main()