torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd_test.py
@@ -0,0 +1,324 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Dict
+from unittest import TestCase
+from unittest.mock import create_autospec, MagicMock
+
+import torch
+from parameterized import parameterized
+from torch import nn, optim, Tensor
+from torch.distributed.distributed_c10d import Work
+from torch.distributed.tensor import DTensor
+
+from torchft.local_sgd import DiLoCo, extract_local_tensor, LocalSGD
+from torchft.manager import Manager
+from torchft.work import _DummyWork
+
+
+def create_manager() -> MagicMock:
+    """
+    Creates a mock manager with some useful defaults for testing
+    the optimizer's usage of the Manager
+    """
+    manager = create_autospec(Manager)
+
+    manager.errored.return_value = None
+
+    def mock_allreduce(tensor: torch.Tensor, should_quantize: bool = False) -> Work:
+        return _DummyWork(tensor)
+
+    manager.allreduce.side_effect = mock_allreduce
+
+    return manager
+
+
+class SimpleModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+
+        self.model = nn.Sequential(
+            nn.Linear(3, 4),
+            nn.ReLU(),
+            nn.Linear(4, 5),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return self.model(x)
+
+
+def _params_dict(m: torch.nn.Module) -> Dict[str, torch.Tensor]:
+    return {name: p.data for name, p in m.named_parameters()}
+
+
+def _copy_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
+    return {name: value.clone().detach() for name, value in state_dict.items()}
+
+
+class TinyModel(nn.Module):
+    def __init__(self) -> None:
+        super().__init__()
+        self.w1 = nn.Parameter(torch.tensor([1.0, 2.0]))
+        self.w2 = nn.Parameter(torch.tensor([3.0, 4.0, 5.0]))
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        return x @ self.w1.unsqueeze(0).T + self.w2.sum()
+
+
+class LocalSGDTest(TestCase):
+    def test_local_sgd_healthy(self) -> None:
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
+        manager = create_manager()
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            self.assertEqual(local_sgd._local_step, 0)
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            self.assertEqual(local_sgd._local_step, 1)
+            self.assertEqual(manager.start_quorum.call_count, 0)
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+            self.assertEqual(manager.start_quorum.call_count, 1)
+
+            manager.should_commit.return_value = True
+            self.assertEqual(local_sgd._local_step, 0)
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, 4)
+
+    def test_extract_local_tensor(self) -> None:
+        regular_tensor = torch.rand(3, 3, requires_grad=True)
+        regular_result = extract_local_tensor(regular_tensor)
+
+        self.assertTrue(torch.equal(regular_result, regular_tensor))
+        self.assertIsNone(regular_result.grad)
+        self.assertNotEqual(id(regular_result), id(regular_tensor))
+        local_tensor = torch.rand(3, 3, requires_grad=True)
+        dtensor = MagicMock(spec=DTensor)
+        dtensor.to_local.return_value = local_tensor
+        dtensor_result = extract_local_tensor(dtensor)
+
+        self.assertTrue(torch.equal(dtensor_result, local_tensor))
+        self.assertIsNone(dtensor_result.grad)
+        self.assertNotEqual(id(dtensor_result), id(local_tensor))
+        dtensor.to_local.assert_called_once()
+
+    def test_local_sgd_recovery(self) -> None:
+        model = SimpleModel()
+        optimizer = optim.SGD(model.parameters())
+        manager = create_autospec(Manager)
+
+        with LocalSGD(manager, model, optimizer, sync_every=2) as local_sgd:
+            og_state_dict = _copy_state_dict(model.state_dict())
+
+            inp = torch.rand(2, 3)
+
+            loss = model(inp).mean()
+            loss.backward()
+            optimizer.step()
+
+            # Check that the model's state dict has been updated
+            for name, param in model.state_dict().items():
+                # Ensure the parameter has changed
+                self.assertFalse(
+                    torch.equal(og_state_dict[name], param),
+                    f"Parameter {name} did not change.",
+                )
+            self.assertEqual(local_sgd._local_step, 1)
+
+
+class DiLoCoTest(TestCase):
+    def test_diloco_healthy(self) -> None:
+        model = SimpleModel()
+
+        # Setup optimizers
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_manager()
+        manager._use_async_quorum = False
+        with DiLoCo(
+            manager, [model], inner_optimizer, outer_optimizer, sync_every=2
+        ) as diloco:
+            parameter_count = len(list(model.parameters()))
+            initial_outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(initial_outer_opt_state["state"], {})
+
+            self.assertEqual(diloco._local_step, 0)
+            torch.testing.assert_close(
+                diloco._fragments[0].original_parameters, _params_dict(model)
+            )
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 1)
+            manager.current_step.return_value = 0
+            manager.should_commit.return_value = True
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            self.assertEqual(diloco._local_step, 0)
+            self.assertEqual(manager.start_quorum.call_count, 1)
+            torch.testing.assert_close(
+                diloco._fragments[0].original_parameters, _params_dict(model)
+            )
+            self.assertEqual(manager.should_commit.call_count, 1)
+            self.assertEqual(manager.allreduce.call_count, parameter_count)
+
+            outer_opt_state = outer_optimizer.state_dict()
+            self.assertEqual(len(outer_opt_state["state"]), parameter_count)
+
+    @parameterized.expand(
+        [
+            ("bucketized_should_use_fewer_calls", True, True),
+            ("non_bucketized_should_call_per_param", False, False),
+        ]
+    )
+    def test_diloco_allreduce_call_efficiency(
+        self,
+        name: str,
+        use_bucketization: bool,
+        expect_fewer_calls: bool,
+    ) -> None:
+        model = SimpleModel()
+
+        inner_optimizer = torch.optim.AdamW(
+            model.parameters(), lr=4e-4, weight_decay=0.1, betas=(0.9, 0.95)
+        )
+        outer_optimizer = torch.optim.SGD(
+            model.parameters(), lr=0.7, momentum=0.9, nesterov=True
+        )
+
+        manager = create_manager()
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        with DiLoCo(
+            manager,
+            [model],
+            inner_optimizer,
+            outer_optimizer,
+            sync_every=2,
+            use_bucketization=use_bucketization,
+        ) as diloco:
+            inp = torch.rand(2, 3)
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            manager.current_step.return_value = 0
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            loss = model(inp).mean()
+            loss.backward()
+            inner_optimizer.step()
+
+            allreduce_calls = manager.allreduce.call_count
+            param_count = len([p for p in model.parameters() if p.requires_grad])
+
+            if expect_fewer_calls:
+                self.assertLess(int(allreduce_calls), int(param_count))
+            else:
+                self.assertEqual(int(allreduce_calls), int(param_count))
+
+    def test_bucketization_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
+            tensor.mul_(2)
+            return _DummyWork(tensor)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(
+            manager, [model], inner_opt, outer_opt, sync_every=2, use_bucketization=True
+        )
+        diloco._fragments[0].bucket_cap_mb = 10 * 1024 * 1024
+
+        # Manually assign fake gradients
+        grads = [torch.tensor([1.0, 2.0]), torch.tensor([3.0, 4.0, 5.0])]
+        for g, (name, param) in zip(grads, model.named_parameters()):
+            diloco._fragments[0]._grads[name] = g.clone()
+
+        # Run only bucketized logic
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # Expect grads to have been doubled
+        expected_grads = [g * 2 for g in grads]
+        for param, expected in zip(model.parameters(), expected_grads):
+            torch.testing.assert_close(param.grad, expected, rtol=1e-5, atol=1e-8)
+
+    def test_gradient_correctness(self) -> None:
+        model = TinyModel()
+        inner_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+        outer_opt = torch.optim.SGD(model.parameters(), lr=0.1)
+
+        manager = create_autospec(Manager)
+        manager._use_async_quorum = False
+        manager.should_commit.return_value = True
+
+        # Define fake allreduce: multiplies buffer by 2
+        def fake_allreduce(tensor: Tensor, should_quantize: bool) -> Work:
+            tensor.mul_(2)
+            return _DummyWork(tensor)
+
+        manager.allreduce.side_effect = fake_allreduce
+
+        diloco = DiLoCo(manager, [model], inner_opt, outer_opt, sync_every=2)
+
+        # save original parameters
+        diloco._fragments[0].save_parameters()
+
+        # change the model's parameters
+        for p in model.parameters():
+            p.data.add_(2)
+
+        # calculate and set the gradients
+        diloco._fragments[0]._save_grads()
+
+        # calculate
+        diloco._fragments[0]._average_grads()
+
+        # The parameter gradients should not be set
+        for param in model.parameters():
+            self.assertEqual(param.grad, None)
+
+        diloco._fragments[0]._set_grads()
+
+        # we added 2 to the parameters, then multiplied the gradients by 2
+        # so we should expect the model's gradient to be -4
+        expected_grad = -4
+        for param in model.parameters():
+            assert param.grad is not None
+            t = torch.empty_like(param.grad)
+            t.fill_(expected_grad)
+            torch.testing.assert_close(param.grad, t)
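
The tests above drive LocalSGD and DiLoCo entirely against a mocked Manager, which is also the simplest way to exercise the API without a real torchft deployment (in a real run the Manager is constructed with an actual process group and connects to a lighthouse). Below is a minimal sketch, following the same pattern as create_manager() and test_local_sgd_healthy; the mock, the dummy allreduce, and the model are illustrative stand-ins, not part of the published file.

# Sketch only: drives LocalSGD the way LocalSGDTest does, with a mocked Manager.
from unittest.mock import create_autospec

import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager
from torchft.work import _DummyWork

manager = create_autospec(Manager)
manager.errored.return_value = None
# Complete each allreduce immediately, as create_manager() does in the tests.
manager.allreduce.side_effect = lambda t, should_quantize=False: _DummyWork(t)

model = nn.Linear(3, 4)
optimizer = optim.SGD(model.parameters(), lr=0.01)

with LocalSGD(manager, model, optimizer, sync_every=2):
    for _ in range(4):
        optimizer.zero_grad()
        loss = model(torch.rand(2, 3)).mean()
        loss.backward()
        # Every sync_every-th step starts a quorum and averages the parameters,
        # which is what test_local_sgd_healthy asserts via the mock's call counts.
        optimizer.step()

DiLoCoTest follows the same shape but splits optimization in two: an inner optimizer (AdamW) stepped every iteration, and an outer optimizer (SGD with Nesterov momentum) applied to the allreduced pseudo-gradients every sync_every steps.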