torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd_test.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+from unittest import TestCase
+from unittest.mock import create_autospec, MagicMock
+
+import torch
+from parameterized import parameterized
+from torch import nn, optim, Tensor
+from torch.distributed.distributed_c10d import Work
+from torch.distributed.tensor import DTensor
+
+from torchft.local_sgd import DiLoCo, extract_local_tensor, LocalSGD
+from torchft.manager import Manager
+from torchft.work import _DummyWork
+
+
+def create_manager() -> MagicMock:
+    """
+    Creates a mock manager with some useful defaults for testing
+    the optimizer's usage of the Manager
+    """
+    manager = create_autospec(Manager)
+
+    manager.errored.return_value = None
+
+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
+    return manager
+
+
+class SimpleModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.Linear(3, 4),
+            nn.ReLU(),
+            nn.Linear(4, 5),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
+
+
+def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    return {name: p.data for name, p in m.named_parameters()}
+
+
+def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    return {name: value.clone().detach() for name, value in state_dict.items()}
+
+
+class TinyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+        self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+
+class LocalSGDTest(TestCase):
+    def test_local_sgd_healthy(self) -> None:
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
+        manager = create_manager()
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            self.assertEqual(local_sgd._local_step, 0)
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            self.assertEqual(local_sgd._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+            self.assertEqual(manager.start_quorum.call_count, 1)
+
+            manager.should_commit.return_value = True
+            self.assertEqual(local_sgd._local_step, 0)
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, 4)
+
+    def test_extract_local_tensor(self) -> None:
+        regular_tensor = torch.rand(3, 3, requires_grad=True)
+        regular_result = extract_local_tensor(regular_tensor)
+
+        self.assertTrue(torch.equal(regular_result, regular_tensor))
+        self.assertIsNone(regular_result.grad)
+        self.assertNotEqual(id(regular_result), id(regular_tensor))
+        local_tensor = torch.rand(3, 3, requires_grad=True)
+        dtensor = MagicMock(spec=DTensor)
+        dtensor.to_local.return_value = local_tensor
+        dtensor_result = extract_local_tensor(dtensor)
+
+        self.assertTrue(torch.equal(dtensor_result, local_tensor))
+        self.assertIsNone(dtensor_result.grad)
+        self.assertNotEqual(id(dtensor_result), id(local_tensor))
+        dtensor.to_local.assert_called_once()
+
+    def test_local_sgd_recovery(self) -> None:
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
+        manager = create_autospec(Manager)
+
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            og_state_dict = _copy_state_dict(model.state_dict())
+
+            inp = torch.rand(2, 3)
+
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            # Check that the model's state dict has been updated
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
+
+
+class DiLoCoTest(TestCase):
+    def test_diloco_healthy(self) -> None:
+        model = SimpleModel()
+
+        # Setup optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_manager()
+        manager._use_async_quorum = False
+        with DiLoCo(
+            manager, [model], inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(
+                diloco._fragments[0].original_parameters, _params_dict(model)
+            )
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            manager.current_step.return_value = 0
+            manager.should_commit.return_value = True
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 0)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            torch.testing.assert_close(
+                diloco._fragments[0].original_parameters, _params_dict(model)
+            )
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
+
+    @parameterized.expand(
+        [
+            ("bucketized_should_use_fewer_calls", True, True),
+            ("non_bucketized_should_call_per_param", False, False),
+        ]
+    )
+    def test_diloco_allreduce_call_efficiency(
+        self,
+        name: str,
+        use_bucketization: bool,
+        expect_fewer_calls: bool,
+    ) -> None:
+        model = SimpleModel()
+
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_manager()
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        with DiLoCo(
+            manager,
+            [model],
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=use_bucketization,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            manager.current_step.return_value = 0
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            allreduce_calls = manager.allreduce.call_count
+            param_count = len([p for p in model.parameters() if p.requires_grad])
+
+            if expect_fewer_calls:
+                self.assertLess(int(allreduce_calls), int(param_count))
+            else:
+                self.assertEqual(int(allreduce_calls), int(param_count))
+
+    def test_bucketization_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
+            tensor.mul_(2)
+            return _DummyWork(tensor)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(
+            manager, [model], inner_opt, outer_opt, sync_every=2, use_bucketization=True
+        )
+        diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024
+
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for g, (name, param) in zip(grads, model.named_parameters()):
+            diloco._fragments[0]._grads[name] = g.clone()
+
+        # Run only bucketized logic
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # Expect grads to have been doubled
+        expected_grads = [g * 2 for g in grads]
+        for param, expected in zip(model.parameters(), expected_grads):
+            torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
+
+    def test_gradient_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
+            tensor.mul_(2)
+            return _DummyWork(tensor)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)
+
+        # save original parameters
+        diloco._fragments[0].save_parameters()
+
+        # change the model's parameters
+        for p in model.parameters():
+            p.data.add_(2)
+
+        # calculate and set the gradients
+        diloco._fragments[0]._save_grads()
+
+        # calculate
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # we added 2 to the parameters, then multiplied the gradients by 2
+        # so we should expect the model's gradient to be -4
+        expected_grad = -4
+        for param in model.parameters():
+            assert param.grad is not None
+            t = torch.empty_like(param.grad)
+            t.fill_(expected_grad)
+            torch.testing.assert_close(param.grad, t)