torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/diloco_regression_test.py
@@ -0,0 +1,644 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import copy
import json
import logging
import os
import sys
import threading
from concurrent.futures import as_completed, ThreadPoolExecutor
from contextlib import ExitStack
from datetime import timedelta
from pathlib import Path
from typing import Any, Callable, cast, Dict, Iterator, List, Optional, overload, Tuple
from unittest import skipIf, TestCase

import torch
from parameterized import parameterized
from torch import nn, optim

from torchft._test.diloco_trainer import DiLoCoTrainer, MultiModel
from torchft._torchft import LighthouseServer
from torchft.local_sgd import DiLoCo
from torchft.manager import Manager
from torchft.manager_integ_test import EventInjector, EventInjectorEvent, Runner

logging.basicConfig(level=logging.INFO)
logger: logging.Logger = logging.getLogger(__name__)


def handle_fixture(
    fixture_filename: str,
    results: list[list[Dict[str, Dict[int, Dict[str, List[float]]]]]],
) -> Optional[list[list[Dict[str, Dict[str, Dict[str, List[float]]]]]]]:
    """
    Handle reading from or writing to a fixture file.

    Args:
        fixture_filename: The name of the fixture file (without path)
        results: The results to write to the fixture file if in write mode

    Returns:
        The fixture data when reading, None when writing
    """
    script_directory = os.path.dirname(os.path.abspath(__file__))
    root_directory = os.path.dirname(script_directory)

    fixture_path = os.path.join(root_directory, "test_fixtures", fixture_filename)

    write_fixture = os.environ.get("WRITE_FIXTURE", "false").lower() == "true"

    if write_fixture:
        # Write results to fixture file
        logger.info(f"Writing fixture to {fixture_path}")
        with open(fixture_path, "w+") as f:
            json.dump(results, f, indent=2)
        return None

    # Read fixture file and return the data
    assert os.path.exists(fixture_path), f"Fixture file {fixture_path} does not exist"
    logger.info(f"Validating against fixture at {fixture_path}")
    with open(fixture_path, "r") as f:
        fixture_data = json.load(f)

    return fixture_data
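

# To regenerate the stored fixtures instead of validating against them, set
# WRITE_FIXTURE when running this file, e.g. (hypothetical invocation; adapt
# to your test runner):
#
#   WRITE_FIXTURE=true python -m pytest torchft/diloco_regression_test.py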


class MockLinear(nn.Module):
    """
    A mock linear layer with deterministic parameter updates.
    """

    def __init__(self, in_features: int, out_features: int) -> None:
        super().__init__()
        # Initialize with specific values to make tracking easier
        self.weight = nn.Parameter(torch.ones(out_features, in_features))

        # Fixed gradients for deterministic updates
        self.weight_grad_value = 2

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # We don't actually do a forward pass, this should not be called
        raise NotImplementedError("MockLinear.forward should never be called")


class MockModel(MultiModel):
    """
    A mock model with deterministic parameter updates.
    """

    def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
        super().__init__()

        for _ in range(n_layers):
            # We don't care about matching dimensionality, we're not going to
            # pass any input through the model
            self.layers.append(MockLinear(in_dim, out_dim))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # We don't actually do a forward pass, this should not be called
        raise NotImplementedError("MockModel.forward should never be called")


class MockOptimizer(optim.Optimizer):
    """
    A mock optimizer with deterministic parameter updates.
    """

    def __init__(self, params: Iterator[torch.nn.Parameter], lr: float = 0.1) -> None:
        defaults = dict(lr=lr)
        super().__init__(params, defaults)

    @overload
    def step(self, closure: None = None) -> None: ...

    @overload
    def step(self, closure: Callable[[], float]) -> float: ...

    def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
        for group in self.param_groups:
            for p in group["params"]:
                if p.grad is None:
                    continue

                # Apply a fixed update rule: subtract lr * grad
                p.data.add_(p.grad.data, alpha=-group["lr"])
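

# Taken together, the mocks above make every inner step pure arithmetic: each
# weight moves by -inner_lr * weight_grad_value per step. With the values used
# below (inner_lr=1, weight_grad_value=2) that is exactly -2 per step, which is
# the constant the expected-value checks in DiLoCoMockedUpdateTest rely on.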


class MockDiLoCoTrainer(DiLoCoTrainer):
    """
    A customized DiLoCoTrainer that uses mock components for deterministic parameter updates.
    """

    def __init__(
        self,
        rank: int,
        store_port: int,
        device: torch.device,
        runner: Runner,
        model_state_dict: dict[str, Any],
        n_fragments: int,
        diloco_args: dict[str, Any],
        inner_lr: float = 1,
        outer_lr: float = 2,
        quorum_barrier: Optional[threading.Barrier] = None,
    ) -> None:
        self.inner_lr = inner_lr
        self.outer_lr = outer_lr

        # Call parent constructor
        super().__init__(
            rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
        )

        self.quorum_barrier = quorum_barrier

    def setup_model(self) -> MockModel:
        """Set up the mock model and move it to the device."""
        model = MockModel(in_dim=1, out_dim=1, n_layers=self.n_fragments)
        model.load_state_dict(self.model_state_dict)
        model.to(self.device)
        return model

    def setup_inner_optimizer(self) -> torch.optim.Optimizer:
        """Set up the mock inner optimizer."""
        return MockOptimizer(self.model.parameters(), lr=self.inner_lr)

    def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
        """Set up mock outer optimizers."""
        outer_optimizers = []
        for i in range(self.n_fragments):
            outer_optimizers.append(
                MockOptimizer(self.model.layers[i].parameters(), lr=self.outer_lr)
            )
        return outer_optimizers

    def train_loop(self) -> Dict[str, Any]:
        """Run the training loop with mocked components."""
        # Ensure sync_every is set in diloco_args
        if "sync_every" not in self.diloco_args:
            self.diloco_args["sync_every"] = 2

        parameter_history = {"history": {}, "global_parameter_history": {}}

        with DiLoCo(
            self.manager,
            [layer for layer in self.model.layers],
            self.inner_optimizer,
            self.outer_optimizers,
            backup_device=self.device,
            **self.diloco_args,
        ) as self.diloco:
            if self.quorum_barrier is not None:
                self.manager.start_quorum()
                self.manager.wait_quorum()
                assert self.quorum_barrier is not None
                self.quorum_barrier.wait()
                assert self.manager.should_commit()
                assert self.manager.should_commit()

            local_step = 0
            manager_steps = set()
            while True:
                # Capture parameters before each step
                step_params = {}
                for name, param in self.model.named_parameters():
                    step_params[name] = param.data.clone().detach().cpu().tolist()
                parameter_history["history"][local_step] = step_params

                manager_curr_step = self.manager.current_step()

                if manager_curr_step == 7:
                    break

                if manager_curr_step not in manager_steps:
                    # Store the manager state dict, converting to the right type
                    state_dict = copy.deepcopy(self.manager._manager_state_dict())
                    user_state_dict = cast(dict[str, object], state_dict["user"])
                    parameter_history["global_parameter_history"][local_step] = {}

                    for i in range(self.n_fragments):
                        value = cast(
                            dict[str, dict[str, torch.Tensor]],
                            user_state_dict[f"StreamingDiLoCoFragment_{i}"],
                        )
                        parameter_history["global_parameter_history"][local_step][
                            f"layers.{i}.weight"
                        ] = (
                            value["original_parameters"]["weight"]
                            .data.clone()
                            .detach()
                            .cpu()
                            .tolist()
                        )

                    manager_steps.add(manager_curr_step)

                # For each parameter, set a deterministic gradient
                for _, layer in enumerate(self.model.layers):
                    if isinstance(layer, MockLinear):
                        # Set fixed gradients
                        layer.weight.grad = (
                            torch.ones_like(layer.weight) * layer.weight_grad_value
                        )

                # Step with deterministic updates
                self.inner_optimizer.step()

                self.runner.event_injector.check(self.rank, self.manager.current_step())
                local_step += 1

        return parameter_history


def mock_diloco_train_loop(
    rank: int,
    store_port: int,
    device: torch.device,
    runner: Runner,
    train_loop_args: dict[str, Any] = {},
) -> Dict[str, Dict[int, Dict[str, List[float]]]]:
    """
    Training loop with mocked components for deterministic parameter updates.
    Uses MockDiLoCoTrainer to handle the training process.
    """
    model_state_dict = train_loop_args.get("model_state_dict", {})
    n_fragments = train_loop_args.get("n_fragments", 1)
    diloco_args = train_loop_args.get("diloco_args", {})
    quorum_barrier = train_loop_args.get("quorum_barrier", None)

    with ExitStack() as stack:
        trainer = MockDiLoCoTrainer(
            rank,
            store_port,
            device,
            runner,
            model_state_dict,
            n_fragments,
            diloco_args,
            quorum_barrier=quorum_barrier,
        )
        stack.callback(trainer.manager.shutdown)
        return trainer.train_loop()
    return {}
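

# Note: the ExitStack callback above guarantees trainer.manager.shutdown() runs
# even if the event injector raises inside train_loop (as it does in the
# failure-recovery test below).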


class DiLoCoMockedUpdateTest(TestCase):
    @parameterized.expand(
        [
            # Format: (use_cuda, n_fragments, fragment_sync_delay, fragment_update_alpha)
            (False, 2, 0, 0),  # 2 fragments, no delay, 0% mixing
            (False, 2, 0, 0.5),  # 2 fragments, no delay, 50% mixing
            (False, 2, 0, 1),  # 2 fragments, no delay, 100% mixing
            (False, 2, 1, 0),  # 2 fragments, with delay, 0% mixing
            (False, 2, 1, 0.5),  # 2 fragments, with delay, 50% mixing
            (False, 2, 1, 1),  # 2 fragments, with delay, 100% mixing
        ]
    )
    def test_diloco_mocked_updates(
        self,
        use_cuda: bool,
        n_fragments: int,
        fragment_sync_delay: int,
        fragment_update_alpha: float,
    ) -> None:
        """
        Test that validates the model parameters are correctly updated by DiLoCo
        using mocked components for deterministic updates with different configurations:
        - n_fragments: Number of model fragments
        - fragment_sync_delay: Delay between preparing and syncing fragments (0 or 1)
        - fragment_update_alpha: Controls mixing of local and global parameters (0.0, 0.5, or 1.0)
        """
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
        sync_every = 6
        num_replicas = 2
        futures = []

        torch.manual_seed(42)
        # Initialize the model with the specified number of fragments.
        # Create a proper state_dict for the model to avoid load_state_dict errors.
        temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
        model_state_dict = temp_model.state_dict()
        quorum_barrier = threading.Barrier(num_replicas)

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id in range(num_replicas):
                event_injector = EventInjector()
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=mock_diloco_train_loop,
                    use_cuda=use_cuda,
                    train_loop_args={
                        "quorum_barrier": quorum_barrier,
                        "n_fragments": n_fragments,
                        "model_state_dict": model_state_dict,
                        "diloco_args": {
                            "sync_every": sync_every,
                            "fragment_sync_delay": fragment_sync_delay,
                            "fragment_update_alpha": fragment_update_alpha,
                        },
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            for fut in as_completed(futures):
                continue

        results = []
        for fut in futures:
            results.append(fut.result())

        lighthouse.shutdown()

        # Check results against fixture or validate parameter updates
        compared_with_fixture = self._check_against_fixture(results)

        if not compared_with_fixture:
            # If no fixture comparison was done, validate parameters directly
            self._validate_parameter_updates(
                results[0][0],
                n_fragments,
                sync_every,
                fragment_sync_delay,
                fragment_update_alpha,
            )

    @parameterized.expand(
        [
            # Format: (use_cuda, n_fragments, fragment_sync_delay, fragment_update_alpha)
            (False, 2, 0, 0),  # 2 fragments, no delay, 0% mixing
        ]
    )
    def test_diloco_mocked_failure_recovery(
        self,
        use_cuda: bool,
        n_fragments: int,
        fragment_sync_delay: int,
        fragment_update_alpha: float,
    ) -> None:
        """
        Test that validates DiLoCo can recover from a replica failure.
        One replica is set to fail at step 2, and the test verifies that
        the system recovers and parameters are correctly synchronized after recovery.
        """
        # Skip the test if use_cuda is True and there are not enough GPUs
        if use_cuda and torch.cuda.device_count() < 2:
            self.skipTest("Not enough GPUs for CUDA test")
        if sys.platform == "darwin":
            self.skipTest("not reliable on mac")

        lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
        sync_every = 6
        num_replicas = 2
        futures = []

        # Create event injectors - make the second replica fail at step 2
        event_injectors = [
            EventInjector(),  # First replica runs normally
            EventInjector().fail_at(0, 2),  # Second replica fails at step 2
        ]

        torch.manual_seed(42)
        # Initialize the model with the specified number of fragments
        temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
        model_state_dict = temp_model.state_dict()

        with ThreadPoolExecutor(max_workers=num_replicas) as executor:
            for replica_id, event_injector in zip(range(num_replicas), event_injectors):
                runner = Runner(
                    replica_id=replica_id,
                    num_replicas=num_replicas,
                    lighthouse_address=lighthouse.address(),
                    event_injector=event_injector,
                    train_loop=mock_diloco_train_loop,
                    use_cuda=use_cuda,
                    train_loop_args={
                        "n_fragments": n_fragments,
                        "model_state_dict": model_state_dict,
                        "diloco_args": {
                            "sync_every": sync_every,
                            "fragment_sync_delay": fragment_sync_delay,
                            "fragment_update_alpha": fragment_update_alpha,
                        },
                    },
                )
                futures.append(executor.submit(runner.run_replica))

            # Wait for all futures to complete
            for fut in as_completed(futures):
                continue

        results = []
        for fut in futures:
            try:
                results.append(fut.result())
            except Exception as e:
                print(f"Error in replica: {e}")
                raise

        lighthouse.shutdown()

        # Check results against fixture or validate failure recovery
        compared_with_fixture = self._check_against_fixture(results)

        if not compared_with_fixture:
            # Verify that the failure was injected
            self.assertEqual(
                event_injectors[1].count[EventInjectorEvent.Failure],
                1,
                "Expected one failure event to be injected",
            )

            # Verify that both replicas have the same global parameters at the end.
            # Extract the global parameter history from both replicas.
            rep0_global = results[0][0]["global_parameter_history"]
            rep1_global = results[1][0]["global_parameter_history"]

            # Get the last step in both histories
            last_step_rep0 = max(int(step) for step in rep0_global.keys())
            last_step_rep1 = max(int(step) for step in rep1_global.keys())

            # Compare the global parameters at the last step
            for param_name in rep0_global[last_step_rep0].keys():
                rep0_param = torch.tensor(rep0_global[last_step_rep0][param_name])
                rep1_param = torch.tensor(rep1_global[last_step_rep1][param_name])

                self.assertTrue(
                    torch.allclose(rep0_param, rep1_param, rtol=1e-5, atol=1e-5),
                    f"Global parameters don't match at the end: {rep0_param} vs {rep1_param} for {param_name}",
                )

    def _check_against_fixture(
        self, results: list[list[Dict[str, Dict[int, Dict[str, List[float]]]]]]
    ) -> bool:
        """
        Check test results against fixture data.

        Args:
            results: The test results to check against fixture

        Returns:
            bool: True if comparison with fixture was performed, False otherwise
        """
        # Handle fixture reading/writing
        fixture_data = handle_fixture(f"{self.id()}.json", results)

        # If no fixture exists we can't compare results
        if fixture_data is None:
            return False

        # Compare fixture data with current results
        for replica_idx, (fixture_history, current_history) in enumerate(
            zip(fixture_data, results)
        ):
            fixture_history = fixture_history[0]["history"]
            current_history = current_history[0]["history"]
            for step, fixture_params in fixture_history.items():
                for param_name, fixture_values in fixture_params.items():
                    current_values = current_history[int(step)][param_name]
                    # Convert to tensors for comparison with tolerance
                    fixture_tensor = torch.tensor(fixture_values)
                    current_tensor = torch.tensor(current_values)
                    self.assertTrue(
                        torch.allclose(
                            fixture_tensor, current_tensor, rtol=1e-5, atol=1e-5
                        ),
                        f"{fixture_tensor} is not the same as {current_tensor} for {param_name} at step {step}",
                    )

        return True

    def _validate_parameter_updates(
        self,
        parameter_history: Dict[str, Dict[int, Dict[str, List[float]]]],
        n_fragments: int,
        sync_every: int,
        fragment_sync_delay: int,
        fragment_update_alpha: float,
    ) -> None:
        """
        Validate that model parameters are updated as expected according to the DiLoCo algorithm.
        Validates both regular steps (inner optimizer updates) and sync steps (outer optimizer updates).

        Args:
            parameter_history: Local and global parameter history for a replica
            n_fragments: Number of model fragments
            sync_every: How often to sync parameters
            fragment_sync_delay: Delay between preparing and syncing fragments
            fragment_update_alpha: Controls mixing of local and global parameters
        """
        # Sync happens every sync_every steps for each fragment
        sync_every_fragment = sync_every // n_fragments

        history = parameter_history["history"]
        global_parameter_history = parameter_history["global_parameter_history"]

        # For each step in history, validate parameter updates
        for step in range(1, 16):  # Skip step 0 (initial state)
            for fragment_param_name in history[step].keys():
                # Get current parameters
                fragment_idx = int(fragment_param_name.split(".")[1]) + 1
                current_params = torch.tensor(history[step][fragment_param_name])

                # Determine if this is a sync step for this fragment.
                # In DiLoCo, fragments are synced in a round-robin fashion:
                # fragment i is synced at steps i*sync_every_fragment + k*sync_every
                # where k is a non-negative integer.
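                # For example, with sync_every=6 and n_fragments=2
                # (sync_every_fragment=3), fragment_idx 1 syncs at steps
                # 3, 9, 15 and fragment_idx 2 at steps 6, 12.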
                is_sync_step = (
                    step - fragment_idx * sync_every_fragment
                ) % sync_every == 0

                if is_sync_step:
                    # This is a sync step for this fragment.
                    # Find the previous sync step for this fragment.
                    prev_sync_step = max(step - sync_every, 0)

                    # Find the prepare step for this fragment (when pseudogradients were calculated)
                    prepare_step = step - fragment_sync_delay

                    # Parameters at the previous sync step (global parameters before update)
                    prev_sync_params = torch.tensor(
                        global_parameter_history[prev_sync_step][fragment_param_name]
                    )

                    # Parameters at the prepare step (before allreduce)
                    prepare_params = (
                        torch.tensor(history[prepare_step - 1][fragment_param_name]) - 2
                    )  # inner_lr (1) * weight_grad_value (2)

                    # Calculate pseudogradient (difference between global and local params)
                    pseudogradient = prev_sync_params - prepare_params

                    # After allreduce, the pseudogradient is averaged across replicas.
                    # In our mock setup, all replicas have the same gradient, so no averaging is needed.
                    averaged_pseudogradient = pseudogradient

                    # Outer optimizer applies this pseudogradient with its learning rate
                    outer_lr = 2

                    # Calculate expected global parameters after outer optimizer update
                    expected_global_params = (
                        prev_sync_params - outer_lr * averaged_pseudogradient
                    )
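                    # Illustrative arithmetic (from the constants above): if
                    # prev_sync_params == 1 and the fragment took three inner
                    # steps of -2 so that prepare_params == -5, then
                    # pseudogradient == 1 - (-5) == 6 and
                    # expected_global_params == 1 - 2 * 6 == -11.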

                    prev_params = torch.tensor(history[step - 1][fragment_param_name])
                    local_params = (
                        prev_params - 2
                    )  # inner_lr (1) * weight_grad_value (2)

                    # lerp: result = local_params * fragment_update_alpha + expected_global_params * (1 - fragment_update_alpha)
                    expected_params = (
                        local_params * fragment_update_alpha
                        + expected_global_params * (1 - fragment_update_alpha)
                    )

                    # Validate synced parameters
                    self.assertTrue(
                        torch.allclose(
                            current_params, expected_params, rtol=1e-5, atol=1e-5
                        ),
                        f"Parameters at sync step {step} for fragment {fragment_param_name} "
                        f"don't match expected: {current_params} vs {expected_params}. "
                        f"{prepare_params=}, {prev_sync_params=}, {pseudogradient=}, {averaged_pseudogradient=}, {expected_global_params=}",
                    )
                else:
                    # Get previous parameters
                    prev_params = torch.tensor(history[step - 1][fragment_param_name])

                    # Regular step (inner optimizer update).
                    # In our mock setup, each step parameters change by -lr * grad = -1 * 2 = -2.
                    expected_params = (
                        prev_params - 2
                    )  # inner_lr (1) * weight_grad_value (2)

                    # Validate the inner-optimizer update
                    self.assertTrue(
                        torch.allclose(
                            current_params, expected_params, rtol=1e-5, atol=1e-5
                        ),
                        f"Parameters at step {step} for fragment {fragment_param_name} "
                        f"don't match expected: {current_params} vs {expected_params}.",
                    )


if __name__ == "__main__":
    import unittest

    unittest.main()

torchft/examples/slurm/README.md
@@ -0,0 +1,34 @@
## Launch lighthouse

Run this command to launch the lighthouse on a node that the other Slurm nodes can reach:

```bash
RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
```
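
Before launching training, you can sanity-check that the lighthouse is reachable from the other nodes. This is an optional check, not part of the example scripts; substitute your own lighthouse host and port (the ones you export as `TORCHFT_LIGHTHOUSE` below):

```bash
$ timeout 2 bash -c 'cat < /dev/null > /dev/tcp/slurm-head-node-0/29510' && echo "lighthouse reachable"
```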

## Launch training

First, go to your local torchtitan folder and run:

```bash
$ pip install -r requirements.txt
$ pip install .
```

Run the following command to launch the torchft lighthouse and replicas using torchtitan on Slurm:

```bash
$ pip install torchx-nightly
$ # Set the address of the lighthouse server, e.g.
$ export TORCHFT_LIGHTHOUSE=http://slurm-head-node-0:29510
$ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=1 --nproc-per-node=8 --replica-count=2
```
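
The same flags control the scale of the run. For example, a hypothetical larger sweep with three replica groups of two nodes each would look like:

```bash
$ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=2 --nproc-per-node=8 --replica-count=3
```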

## Test fault tolerance

To inject some failures, you can use the following command:

```bash
$ python punisher.py kill_loop --num-failures=10 --mtbf-secs=300
```