torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager_test.py
@@ -0,0 +1,911 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import concurrent
+ import threading
+ import time
+ from datetime import timedelta
+ from typing import Optional
+ from unittest import TestCase
+ from unittest.mock import create_autospec, MagicMock, patch
+
+ import torch
+ from torch.distributed import ReduceOp, TCPStore
+
+ from torchft._torchft import QuorumResult
+ from torchft.checkpointing._rwlock import RWLock
+ from torchft.checkpointing.transport import CheckpointTransport
+ from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode
+ from torchft.process_group import ProcessGroup
+ from torchft.work import _DummyWork
+
+
+ def mock_should_commit(
+     rank: int, step: int, should_commit: bool, timeout: timedelta
+ ) -> bool:
+     return should_commit
+
+
+ class TestManager(TestCase):
+     store: TCPStore  # pyre-fixme[13]: never initialized
+     load_state_dict: MagicMock  # pyre-fixme[13]: never initialized
+     manager: Optional[Manager]  # pyre-fixme[13]: never initialized
+
+     def tearDown(self) -> None:
+         # shut down the manager created by _create_manager, if one exists
+         if hasattr(self, "manager") and self.manager is not None:
+             self.manager.shutdown(wait=False)
+     def _create_manager(
+         self,
+         use_async_quorum: bool = True,
+         min_replica_size: int = 2,
+         world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
+         timeout: timedelta = timedelta(seconds=10),
+         init_sync: bool = True,
+         max_retries: Optional[int] = None,
+     ) -> Manager:
+         pg = create_autospec(ProcessGroup)
+         pg.errored.return_value = None
+
+         self.store = TCPStore(
+             host_name="localhost", port=0, is_master=True, wait_for_workers=False
+         )
+         self.store.set(MANAGER_ADDR_KEY, "dummy")
+         self.store.set(REPLICA_ID_KEY, "dummy_id")
+         with patch(
+             "os.environ",
+             {
+                 "MASTER_ADDR": "localhost",
+                 "MASTER_PORT": self.store.port,
+                 "RANK": "1",
+                 "WORLD_SIZE": "2",
+             },
+         ):
+             self.load_state_dict = MagicMock()
+             manager = Manager(
+                 pg=pg,
+                 min_replica_size=min_replica_size,
+                 load_state_dict=self.load_state_dict,
+                 state_dict=lambda: {},
+                 use_async_quorum=use_async_quorum,
+                 world_size_mode=world_size_mode,
+                 timeout=timeout,
+                 init_sync=init_sync,
+                 max_retries=max_retries,
+             )
+             self.manager = manager
+             return manager
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+         self.assertEqual(client_mock.call_count, 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_state_dict(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+
+         state_dict = manager.state_dict()
+         self.assertEqual(
+             state_dict,
+             {
+                 "step": 0,
+                 "batches_committed": 0,
+             },
+         )
+
+         manager.load_state_dict(
+             {
+                 "step": 1234,
+                 "batches_committed": 2345,
+             }
+         )
+         self.assertEqual(manager.current_step(), 1234)
+         self.assertEqual(manager.batches_committed(), 2345)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_user_state_dict(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+
+         self.assertEqual(
+             manager._manager_state_dict(),
+             {
+                 "user": {
+                     "default": {},
+                 },
+                 "torchft": {
+                     "step": 0,
+                     "batches_committed": 0,
+                 },
+             },
+         )
+
+         manager.register_state_dict_fn(
+             "state",
+             self.load_state_dict,
+             lambda: {"new_state": 1},
+         )
+
+         self.assertEqual(
+             manager._manager_state_dict(),
+             {
+                 "user": {
+                     "default": {},
+                     "state": {"new_state": 1},
+                 },
+                 "torchft": {
+                     "step": 0,
+                     "batches_committed": 0,
+                 },
+             },
+         )
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_happy(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+
+         client_mock()._quorum.return_value = quorum
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+         self.assertEqual(manager.batches_committed(), 0)
+
+         manager.start_quorum()
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertTrue(manager.should_commit())
+
+         self.assertEqual(manager._quorum_id, 123)
+         self.assertEqual(manager.current_step(), 1)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.call_count, 1)
+
+         manager.start_quorum()
+         self.assertEqual(manager.batches_committed(), 2)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(use_async_quorum=False)
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 0
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 20
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 2
+         quorum.heal = True
+
+         client_mock()._quorum.return_value = quorum
+
+         # forcibly increment the checkpoint server to compute the correct address
+         manager._checkpoint_transport.send_checkpoint(
+             dst_ranks=[],
+             step=quorum.max_step,
+             state_dict=manager._manager_state_dict(),
+             timeout=timedelta(seconds=10),
+         )
+         client_mock()._checkpoint_metadata.return_value = (
+             manager._checkpoint_transport.metadata()
+         )
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+
+         self.assertEqual(manager.num_participants(), 0)
+         self.assertEqual(manager.participating_rank(), None)
+
+         manager.start_quorum()
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertFalse(manager._healing)
+         self.assertTrue(manager.is_participating())
+         self.assertEqual(manager.num_participants(), 2)
+         self.assertTrue(manager.should_commit())
+
+         self.assertEqual(manager._quorum_id, 123)
+         self.assertEqual(manager.current_step(), 21)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.call_count, 1)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
+
+         self.assertEqual(self.load_state_dict.call_count, 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_heal_async_not_enough_participants(
+         self, client_mock: MagicMock
+     ) -> None:
+         manager = self._create_manager(use_async_quorum=True, min_replica_size=2)
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 0
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 20
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 1
+         quorum.heal = True
+
+         client_mock()._quorum.return_value = quorum
+
+         # forcibly increment the checkpoint server to compute the correct address
+         manager._checkpoint_transport.send_checkpoint(
+             dst_ranks=[],
+             step=quorum.max_step,
+             state_dict=manager._manager_state_dict(),
+             timeout=timedelta(seconds=10),
+         )
+         client_mock()._checkpoint_metadata.return_value = (
+             manager._checkpoint_transport.metadata()
+         )
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+
+         manager.start_quorum()
+         assert manager._quorum_future is not None
+         manager._quorum_future.result()
+         self.assertTrue(manager._healing)
+         self.assertFalse(manager.is_participating())
+         self.assertEqual(manager.num_participants(), 1)
+
+         grad = torch.tensor([1.0])
+         manager.allreduce(grad).wait()
+         torch.testing.assert_close(grad, torch.zeros_like(grad))
+         # don't commit since num_max < min_replica_size
+         self.assertFalse(manager.should_commit())
+         self.assertEqual(manager.current_step(), 20)
+
+         self.assertEqual(manager._quorum_id, 123)
+         self.assertEqual(manager.current_step(), 20)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.call_count, 1)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
+
+         self.assertEqual(self.load_state_dict.call_count, 1)
+
+         # failed to commit so no step
+         quorum.heal = False
+         manager.start_quorum()
+         self.assertEqual(manager.current_step(), 20)
+         self.assertEqual(manager.batches_committed(), 0)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(use_async_quorum=True, min_replica_size=1)
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 0
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 20
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 1
+         quorum.heal = True
+
+         client_mock()._quorum.return_value = quorum
+
+         # forcibly increment the checkpoint server to compute the correct address
+         manager._checkpoint_transport.send_checkpoint(
+             dst_ranks=[],
+             step=quorum.max_step,
+             state_dict=manager._manager_state_dict(),
+             timeout=timedelta(seconds=10),
+         )
+         client_mock()._checkpoint_metadata.return_value = (
+             manager._checkpoint_transport.metadata()
+         )
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+
+         manager.start_quorum()
+         assert manager._quorum_future is not None
+         manager._quorum_future.result()
+         self.assertTrue(manager._healing)
+
+         grad = torch.tensor([1.0])
+         manager.allreduce(grad).wait()
+         torch.testing.assert_close(grad, torch.zeros_like(grad))
+         # commit is allowed since num_max >= min_replica_size
+         self.assertTrue(manager.should_commit())
+         self.assertEqual(manager.num_participants(), 1)
+         self.assertEqual(manager.current_step(), 21)
+
+         self.assertEqual(manager._quorum_id, 123)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.call_count, 1)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
+
+         self.assertEqual(self.load_state_dict.call_count, 1)
+
+         # healed
+         quorum.heal = False
+         manager.start_quorum()
+         self.assertEqual(manager.current_step(), 21)
+         self.assertEqual(manager.batches_committed(), 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_allreduce_error(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+
+         client_mock()._quorum.return_value = quorum
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+
+         manager.start_quorum()
+         manager.allreduce(torch.tensor([1.0])).wait()
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.call_count, 1)
+
+         # inject failure when the work is queued
+         # pyre-ignore[16]: _pg is mocked
+         manager._pg.allreduce.side_effect = RuntimeError("injected failure")
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertTrue(manager._errored)
+         # this should be skipped due to error
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertEqual(manager._pg.allreduce.call_count, 2)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)
+
+         self.assertFalse(manager.should_commit())
+         self.assertTrue(manager._errored)
+
+         # cleanup
+         manager._pg.allreduce.side_effect = None
+
+         # inject failure when the work is waited on
+         quorum.max_step = 2
+
+         manager.start_quorum()
+
+         self.assertFalse(manager._errored)
+
+         bad_fut = torch.futures.Future()
+         bad_fut.set_exception(RuntimeError("injected failure"))
+         manager._pg.allreduce.return_value.get_future.return_value = bad_fut
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
+         self.assertTrue(manager._errored)
+         self.assertFalse(manager.should_commit())
+         self.assertTrue(manager._errored)
+
+         # cleanup
+         manager._pg.allreduce.reset_mock(return_value=True)
+
+         # recover on next step
+         quorum.max_step = 3
+
+         manager.start_quorum()
+         manager.allreduce(torch.tensor([1.0])).wait()
+         self.assertTrue(manager.should_commit())
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_pg_errored(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+
+         client_mock()._quorum.return_value = quorum
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+
+         manager.start_quorum()
+
+         injected_failure = RuntimeError("injected failure")
+
+         # pyre-ignore[16]: _pg is mocked
+         manager._pg.errored.return_value = injected_failure
+
+         self.assertFalse(manager.should_commit())
+         assert manager._errored is not None
+         self.assertEqual(manager._errored.original_exception, injected_failure)
+         # pyre-ignore[16]: _pg is mocked
+         self.assertEqual(manager._pg.errored.call_count, 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
+         # test active and spares
+         for rank in [1, 2]:
+             manager = self._create_manager(
+                 min_replica_size=2,
+                 world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
+             )
+             client_mock().should_commit = mock_should_commit
+
+             quorum = QuorumResult()
+             quorum.quorum_id = 123
+             quorum.replica_rank = rank
+             quorum.replica_world_size = 3
+             quorum.recover_src_manager_address = "manager address"
+             quorum.store_address = f"localhost:{self.store.port}"
+             quorum.max_step = 1
+             quorum.max_replica_rank = rank
+             quorum.max_world_size = 3
+             quorum.heal = False
+
+             client_mock()._quorum.return_value = quorum
+
+             self.assertEqual(manager._quorum_id, -1)
+             self.assertEqual(manager.current_step(), 0)
+             self.assertEqual(manager.batches_committed(), 0)
+
+             manager.start_quorum()
+             manager.allreduce(torch.tensor([1.0])).wait()
+
+             self.assertEqual(manager.is_participating(), rank != 2)
+             self.assertEqual(manager.num_participants(), 2)
+
+             self.assertTrue(manager.should_commit())
+             self.assertEqual(manager.batches_committed(), 2)
+             self.assertEqual(manager.current_step(), 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_no_healing(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(
+             min_replica_size=2,
+         )
+         client_mock().should_commit = mock_should_commit
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 0
+         quorum.replica_world_size = 3
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 1
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 2
+         quorum.heal = True
+         client_mock()._quorum.return_value = quorum
+
+         self.assertEqual(manager._quorum_id, -1)
+         self.assertEqual(manager.current_step(), 0)
+         self.assertEqual(manager.batches_committed(), 0)
+
+         manager.start_quorum(allow_heal=False)
+         manager.allreduce(torch.tensor([1.0])).wait()
+
+         self.assertFalse(manager.is_participating())
+         self.assertEqual(manager.num_participants(), 2)
+
+         self.assertTrue(manager.should_commit())
+         self.assertEqual(manager.batches_committed(), 2)
+         self.assertEqual(manager.current_step(), 1)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager_report_error(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+
+         self.assertIsNone(manager.errored())
+         e = RuntimeError("some error")
+         manager.report_error(e)
+         error = manager.errored()
+         assert error is not None
+         self.assertIs(error.original_exception, e)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+
+         self.assertIsNone(manager.errored())
+
+         fut = torch.futures.Future()
+         wrapped_fut = manager.wrap_future(fut, 2)
+         self.assertIsNone(manager.errored())
+
+         e = RuntimeError("injected failure")
+         fut.set_exception(e)
+         error = manager.errored()
+         assert error is not None
+         self.assertIs(error.original_exception, e)
+         self.assertEqual(wrapped_fut.value(), 2)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(timeout=timedelta(seconds=0.01))
+
+         self.assertFalse(manager.errored())
+
+         fut = torch.futures.Future()
+         wrapped_fut = manager.wrap_future(fut, 2)
+         wrapped_fut.wait()
+         error = manager.errored()
+         assert error is not None
+         with self.assertRaisesRegex(
+             TimeoutError, "future did not complete within.*0.01"
+         ):
+             raise error.original_exception
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager_numerics(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager()
+
+         manager._quorum_future = quorum_future = MagicMock(
+             spec=concurrent.futures.Future
+         )
+         manager._participating_replica_rank = 1
+         manager._participating_replica_world_size = 5
+         self.assertEqual(manager.num_participants(), 5)
+         self.assertEqual(quorum_future.result.call_count, 1)
+         self.assertEqual(manager.participating_rank(), 1)
+         self.assertEqual(quorum_future.result.call_count, 2)
+
+         # pyre-ignore[16]: _pg is mocked
+         manager._pg.allreduce.return_value = _DummyWork(None)
+
+         self.assertTrue(manager.is_participating())
+
+         for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.long):
+             orig = torch.tensor([10], dtype=dtype)
+
+             if torch.is_floating_point(orig):
+                 tensor = orig.clone()
+                 manager.allreduce(tensor).wait()
+                 torch.testing.assert_close(tensor, orig / 5)
+
+                 tensor = orig.clone()
+                 manager.allreduce(tensor, reduce_op=ReduceOp.AVG).wait()
+                 torch.testing.assert_close(tensor, orig / 5)
+
+             for reduce_op in [
+                 ReduceOp.SUM,
+                 ReduceOp.MAX,
+                 ReduceOp.MIN,
+                 ReduceOp.PRODUCT,
+             ]:
+                 tensor = orig.clone()
+                 manager.allreduce(tensor, reduce_op=reduce_op).wait()
+                 torch.testing.assert_close(tensor, orig)
+
+         # check healing numerics
+         manager._healing = True
+         self.assertFalse(manager.is_participating())
+         tensor = torch.tensor([1.0])
+         work = manager.allreduce(tensor)
+         work.wait()
+         torch.testing.assert_close(tensor, torch.tensor([0.0]))
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(use_async_quorum=False)
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+
+         client_mock()._quorum.return_value = quorum
+
+         manager.start_quorum(timeout=timedelta(seconds=12))
+         self.assertEqual(
+             client_mock()._quorum.call_args.kwargs["timeout"], timedelta(seconds=12)
+         )
+
+         self.assertTrue(manager.should_commit(timeout=timedelta(seconds=23)))
+         self.assertEqual(
+             client_mock().should_commit.call_args.kwargs["timeout"],
+             timedelta(seconds=23),
+         )
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(
+             use_async_quorum=False,
+             init_sync=False,
+         )
+
+         self.assertFalse(manager._init_sync)
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+
+         client_mock()._quorum.return_value = quorum
+
+         manager.start_quorum()
+         self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)
+
+         manager._init_sync = True
+         manager.start_quorum()
+         self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(use_async_quorum=True)
+         client_mock().should_commit = MagicMock(return_value=False)
+
+         transport = MagicMock(spec=CheckpointTransport)
+         transport.send_checkpoint.side_effect = RuntimeError("send failure")
+         transport.recv_checkpoint.side_effect = RuntimeError("recv failure")
+         manager._checkpoint_transport = transport
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 0
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 20
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 2
+         quorum.heal = True
+
+         client_mock()._quorum.return_value = quorum
+
+         manager.start_quorum()
+         manager.wait_quorum()
+         self.assertFalse(manager.should_commit())
+
+         error = manager.errored()
+         assert error is not None
+         with self.assertRaisesRegex(RuntimeError, "recv failure"):
+             raise error.original_exception
+
+         quorum.recover_dst_replica_ranks = [0]
+         manager.start_quorum()
+         manager.wait_quorum()
+         self.assertFalse(manager.should_commit())
+
+         error = manager.errored()
+         assert error is not None
+         with self.assertRaisesRegex(RuntimeError, "send failure"):
+             raise error.original_exception
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
+         manager = self._create_manager(use_async_quorum=True)
+         client_mock().should_commit = MagicMock(return_value=False)
+
+         # pyre-ignore[16]: mock
+         manager._pg.configure.side_effect = RuntimeError("configure failure")
+
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.recover_src_replica_rank = 0
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 20
+         quorum.max_replica_rank = None
+         quorum.max_world_size = 2
+
+         client_mock()._quorum.return_value = quorum
+
+         manager.start_quorum()
+         manager.wait_quorum()
+         self.assertFalse(manager.should_commit())
+
+         error = manager.errored()
+         assert error is not None
+         with self.assertRaisesRegex(RuntimeError, "configure failure"):
+             raise error.original_exception
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_max_retries(self, client_mock: MagicMock) -> None:
+         # Create a manager with max_retries=2
+         manager = self._create_manager(max_retries=2)
+
+         # Setup quorum for testing
+         quorum = QuorumResult()
+         quorum.quorum_id = 123
+         quorum.replica_rank = 1
+         quorum.replica_world_size = 2
+         quorum.recover_src_manager_address = "manager address"
+         quorum.store_address = f"localhost:{self.store.port}"
+         quorum.max_step = 1
+         quorum.max_replica_rank = 1
+         quorum.max_world_size = 2
+         quorum.heal = False
+         client_mock()._quorum.return_value = quorum
+
+         # Make should_commit always return False to simulate failures
+         client_mock().should_commit = MagicMock(return_value=False)
+
+         # Start quorum
+         manager.start_quorum()
+
+         # First failure
+         self.assertFalse(manager.should_commit())
+         self.assertEqual(manager._commit_failures, 1)
+
+         # Second failure
+         self.assertFalse(manager.should_commit())
+         self.assertEqual(manager._commit_failures, 2)
+
+         # Third failure - should raise exception
+         with self.assertRaises(RuntimeError) as context:
+             manager.should_commit()
+
+         self.assertIn("exceeding max_retries=2", str(context.exception))
+         self.assertEqual(manager._commit_failures, 3)
+
+         # Now test that success resets the counter
+         manager._commit_failures = 2  # Reset to just before failure threshold
+         client_mock().should_commit = MagicMock(return_value=True)  # Now succeed
+
+         # This should succeed and reset the counter
+         self.assertTrue(manager.should_commit())
+         self.assertEqual(manager._commit_failures, 0)
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_state_dict_lock_allow_disallow(self, client_mock: MagicMock) -> None:
+         """Test that allow_state_dict_read and disallow_state_dict_read methods work correctly."""
+         manager = self._create_manager()
+
+         # Initially, state dict read should be allowed
+         self.assertTrue(manager._is_state_dict_read_allowed)
+
+         # Test disallow_state_dict_read
+         manager.disallow_state_dict_read()
+         self.assertFalse(manager._is_state_dict_read_allowed)
+         self.assertTrue(manager._state_dict_lock.w_locked())
+
+         # Calling disallow_state_dict_read again should be a no-op
+         manager.disallow_state_dict_read()
+         self.assertFalse(manager._is_state_dict_read_allowed)
+         self.assertTrue(manager._state_dict_lock.w_locked())
+
+         # Test allow_state_dict_read
+         manager.allow_state_dict_read()
+         self.assertTrue(manager._is_state_dict_read_allowed)
+         self.assertFalse(manager._state_dict_lock.w_locked())
+
+         # Calling allow_state_dict_read again should be a no-op
+         manager.allow_state_dict_read()
+         self.assertTrue(manager._is_state_dict_read_allowed)
+         self.assertFalse(manager._state_dict_lock.w_locked())
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_state_dict_lock_concurrent_access(self, client_mock: MagicMock) -> None:
+         """Test that _state_dict_lock properly protects concurrent access to the state dictionary."""
+         manager: Manager = self._create_manager()
+
+         # Create flags for thread synchronization
+         access_attempted: threading.Event = threading.Event()
+         can_proceed: threading.Event = threading.Event()
+         access_result: dict[str, bool] = {"succeeded": False}
+
+         def try_access_state_dict() -> None:
+             # Wait until the main thread signals it's ready
+             nonlocal access_attempted, can_proceed, access_result, manager
+             access_attempted.set()
+             can_proceed.wait(timeout=1.0)
+
+             # Try to access the state dict
+             if manager._is_state_dict_read_allowed:
+                 access_result["succeeded"] = True
+
+         # Start a thread that will try to access the state dict
+         thread = threading.Thread(target=try_access_state_dict)
+         thread.daemon = True
+         thread.start()
+
+         # Disallow state dict read
+         manager.disallow_state_dict_read()
+         self.assertFalse(manager._is_state_dict_read_allowed)
+
+         # Wait for the thread to be ready
+         access_attempted.wait(timeout=1.0)
+
+         # Signal the thread to proceed while state dict read is disallowed
+         can_proceed.set()
+         thread.join(timeout=1.0)
+
+         # The thread should not have been able to access the state dict
+         self.assertFalse(access_result["succeeded"])
+
+         # Reset for the second part of the test
+         access_attempted.clear()
+         can_proceed.clear()
+
+         # Start another thread
+         thread = threading.Thread(target=try_access_state_dict)
+         thread.daemon = True
+         thread.start()
+
+         # Allow state dict read
+         manager.allow_state_dict_read()
+         self.assertTrue(manager._is_state_dict_read_allowed)
+
+         # Wait for the thread to be ready
+         access_attempted.wait(timeout=1.0)
+
+         # Signal the thread to proceed while state dict read is allowed
+         can_proceed.set()
+         thread.join(timeout=1.0)
+
+         # The thread should now have been able to access the state dict
+         self.assertTrue(access_result["succeeded"])
+
+     @patch("torchft.manager.ManagerClient", autospec=True)
+     def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None:
+         """Test that _manager_state_dict properly uses the read lock."""
+         manager = self._create_manager()
+
+         # Replace the real RWLock with a mock to track lock acquisition
+         original_lock = manager._state_dict_lock
+         mock_lock = create_autospec(RWLock)
+         mock_context = MagicMock()
+         mock_lock.r_lock.return_value.__enter__ = lambda _: mock_context
+         mock_lock.r_lock.return_value.__exit__ = lambda *args: None
+         manager._state_dict_lock = mock_lock
+
+         # Call _manager_state_dict
+         result = manager._manager_state_dict()
+
+         # Verify that r_lock was called
+         mock_lock.r_lock.assert_called_once()
+
+         # Restore the original lock
+         manager._state_dict_lock = original_lock
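
The test file above drives the core Manager training-loop API: start_quorum() before each step, allreduce() on the gradients, and should_commit() to gate the optimizer update. Below is a minimal sketch of that per-step pattern, based only on the calls exercised in manager_test.py; the train_step helper, the model/optimizer wiring, and the loss function are illustrative assumptions, and constructing the Manager itself (not shown) additionally needs the distributed environment and state_dict/load_state_dict callbacks that the tests patch in via _create_manager.

    from datetime import timedelta

    import torch
    import torch.nn.functional as F

    from torchft.manager import Manager


    def train_step(
        manager: Manager,
        model: torch.nn.Module,
        optim: torch.optim.Optimizer,
        batch: torch.Tensor,
        labels: torch.Tensor,
    ) -> None:
        # One fault-tolerant step, mirroring the sequence the tests assert on:
        # start_quorum() -> backward -> allreduce(grad).wait() -> should_commit().
        manager.start_quorum()

        optim.zero_grad()
        loss = F.cross_entropy(model(batch), labels)
        loss.backward()

        # Average gradients across the participating replicas; collective errors
        # surface through the manager (see test_allreduce_error) rather than here.
        for p in model.parameters():
            if p.grad is not None:
                manager.allreduce(p.grad).wait()

        # Apply the update only if the quorum agrees the step succeeded,
        # as test_quorum_happy and test_max_retries exercise.
        if manager.should_commit(timeout=timedelta(seconds=30)):
            optim.step()

As test_quorum_heal_async_zero_grad shows, a replica that is still healing gets zeroed gradients and is excluded from participation, so the loop itself does not need any recovery logic.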