torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd_integ_test.py
@@ -0,0 +1,600 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import copy
+ import logging
+ import os
+ import re
+ import sys
+ import threading
+ import traceback
+ from concurrent.futures import as_completed, ThreadPoolExecutor
+ from contextlib import ExitStack
+ from dataclasses import field
+ from datetime import timedelta
+ from typing import Any, cast, Dict
+ from unittest import skipIf, TestCase
+
+ import torch
+ from parameterized import parameterized
+ from torch import nn, optim
+ from torch.distributed.pipelining import pipeline, SplitPoint
+ from torch.distributed.tensor import DTensor, Replicate
+
+ from torchft._test.diloco_trainer import DiLoCoTrainer, MultiMyModel
+ from torchft._torchft import LighthouseServer
+ from torchft.local_sgd import DiLoCo, LocalSGD
+ from torchft.manager import Manager
+ from torchft.manager_integ_test import (
+     EventInjector,
+     EventInjectorEvent,
+     MyModel,
+     Runner,
+ )
+ from torchft.process_group import (
+     FakeProcessGroupWrapper,
+     ProcessGroupBabyNCCL,
+     ProcessGroupGloo,
+ )
+
+ logger: logging.Logger = logging.getLogger(__name__)
+ logging.basicConfig(level=logging.INFO)
+
+
+ def local_sgd_train_loop(
+     rank: int,
+     store_port: int,
+     device: torch.device,
+     runner: Runner,
+     train_loop_args: dict[str, Any] = {},
+ ) -> Dict[str, Dict[str, object]]:
+     with ExitStack() as stack:
+
+         def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
+             m.load_state_dict(state_dict["model"])
+             optimizer.load_state_dict(state_dict["optim"])
+
+         def state_dict() -> Dict[str, Dict[str, object]]:
+             return {
+                 "model": m.state_dict(),
+                 "optim": optimizer.state_dict(),
+             }
+
+         print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
+
+         if device.type == "cuda":
+             pg = ProcessGroupBabyNCCL()
+         else:
+             pg = ProcessGroupGloo()
+         manager = Manager(
+             pg=pg,
+             min_replica_size=2,
+             load_state_dict=load_state_dict,
+             state_dict=state_dict,
+             replica_id=str(runner.replica_id),
+             store_addr="localhost",
+             store_port=store_port,
+             rank=rank,
+             world_size=runner.world_size,
+             lighthouse_addr=runner.lighthouse_address,
+             port=19530 + runner.replica_id,
+             timeout=timedelta(seconds=10),
+             # pyre-fixme[6]: Incompatible parameter type
+             **runner.manager_args,
+         )
+         stack.callback(lambda: manager.shutdown(wait=False))
+
+         m: nn.Module = MyModel().to(device)
+
+         optimizer: optim.Optimizer = optim.Adam(m.parameters())
+         criterion = nn.CrossEntropyLoss()
+
+         with LocalSGD(manager, m, optimizer, sync_every=2) as local_sgd:
+             while True:
+                 inputs = torch.rand(2, 3).to(device)
+                 labels = torch.randint(4, (2,)).to(device)
+
+                 optimizer.zero_grad()
+                 out = m(inputs)
+                 loss = criterion(out, labels)
+                 loss.backward()
+
+                 optimizer.step()
+
+                 if manager.current_step() >= 4:
+                     break
+
+                 runner.event_injector.check(rank, manager.current_step())
+
+         # return state_dict so we can check consistency
+         return state_dict()
+     return {}
+
+
+ def diloco_train_loop(
+     rank: int,
+     store_port: int,
+     device: torch.device,
+     runner: Runner,
+     train_loop_args: dict[str, Any] = {},
+ ) -> Dict[str, Dict[str, object]]:
+     model_state_dict = train_loop_args.get("model_state_dict", {})
+     n_fragments = train_loop_args.get("n_fragments", 1)
+     diloco_args = train_loop_args.get("diloco_args", {})
+
+     with ExitStack() as stack:
+         trainer = DiLoCoTrainer(
+             rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
+         )
+         stack.callback(trainer.manager.shutdown)
+         return trainer.train_loop()
+     return {}
+
+
+ def assert_equal_global_state(
+     n_fragments: int,
+     rep0: dict[str, dict[str, dict[str, dict[str, object]]]],
+     rep1: dict[str, dict[str, dict[str, dict[str, object]]]],
+ ) -> None:
+     """
+     Asserts that the global state of the two replicas is equal
+     """
+     for step in rep0.keys():
+         for i in range(n_fragments):
+             torch.testing.assert_close(
+                 rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
+                     "original_parameters"
+                 ],
+                 rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
+                     "original_parameters"
+                 ],
+                 check_device=False,
+                 msg=f"{step=} {i=}",
+             )
+             # Check all outer optimizers
+             torch.testing.assert_close(
+                 cast(
+                     dict[str, dict[str, torch.Tensor]],
+                     rep1[step]["user"][f"StreamingDiLoCoFragment_{i}"][
+                         "outer_optimizer"
+                     ],
+                 ),
+                 cast(
+                     dict[str, dict[str, torch.Tensor]],
+                     rep0[step]["user"][f"StreamingDiLoCoFragment_{i}"][
+                         "outer_optimizer"
+                     ],
+                 ),
+                 check_device=False,
+             )
+
+
+ class LocalSGDIntegTest(TestCase):
+     # TODO: race condition due to using NCCL in threads causes manager allreduce to sometimes not be correct
+     # Because of that the test is disabled for cuda
+     @parameterized.expand(
+         [
+             # (True,),
+             (False,),
+         ]
+     )
+     def test_local_sgd_recovery(self, use_cuda: bool) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 2
+         futures = []
+
+         event_injectors = [
+             EventInjector(),
+             EventInjector().fail_at(0, 2),
+         ]
+
+         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+             for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+                 runner = Runner(
+                     replica_id=replica_id,
+                     num_replicas=num_replicas,
+                     lighthouse_address=lighthouse.address(),
+                     event_injector=event_injector,
+                     train_loop=local_sgd_train_loop,
+                     use_cuda=use_cuda,
+                     manager_args={
+                         "use_async_quorum": False,
+                     },
+                 )
+                 futures.append(executor.submit(runner.run_replica))
+
+             state_dicts = []
+
+             for fut in as_completed(futures):
+                 try:
+                     state_dicts.append(fut.result())
+                 except Exception as e:
+                     print(e)
+                     raise
+
+         lighthouse.shutdown()
+
+         for state_dict in state_dicts:
+             # LocalSGD only guarantees that the model is consistent across
+             # replicas but uses separate optimizer states.
+             torch.testing.assert_close(
+                 state_dict[0]["model"], state_dicts[0][0]["model"], check_device=False
+             )
+
+         self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
+
+     @parameterized.expand(
+         [
+             # (True,),
+             (False,),
+         ]
+     )
+     def test_diloco_healthy(self, use_cuda: bool) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
+         num_replicas = 2
+         futures = []
+
+         torch.manual_seed(42)
+         # Initialize the model so we can pass in the state_dict
+         m: nn.Module = MultiMyModel(2, 3, 1)
+
+         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+             for replica_id in range(num_replicas):
+                 event_injector = EventInjector()
+                 runner = Runner(
+                     replica_id=replica_id,
+                     num_replicas=num_replicas,
+                     lighthouse_address=lighthouse.address(),
+                     event_injector=event_injector,
+                     train_loop=diloco_train_loop,
+                     use_cuda=use_cuda,
+                     train_loop_args={
+                         "model_state_dict": m.state_dict(),
+                     },
+                 )
+                 futures.append(executor.submit(runner.run_replica))
+
+             state_dicts = []
+             for fut in as_completed(futures):
+                 try:
+                     state_dicts.append(fut.result()[0])
+                 except Exception as e:
+                     print(e, flush=True)
+                     traceback.print_exc()
+                     raise
+
+         lighthouse.shutdown()
+
+         rep0, rep1 = state_dicts
+         assert_equal_global_state(1, rep1, rep0)
+
+     # pyre-fixme[56]: Pyre was not able to infer the type of argument
+     @skipIf(sys.platform == "darwin", "not reliable on mac")
+     @parameterized.expand(
+         [
+             # (True,),
+             (False,),
+         ]
+     )
+     def test_diloco_recovery(self, use_cuda: bool) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 2
+         futures = []
+
+         event_injectors = [
+             EventInjector(),
+             EventInjector().fail_at(0, 2),
+         ]
+
+         torch.manual_seed(42)
+         # Initialize the model so we can pass in the state_dict
+         m: nn.Module = MultiMyModel(2, 3, 1)
+
+         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+             for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+                 runner = Runner(
+                     replica_id=replica_id,
+                     num_replicas=num_replicas,
+                     lighthouse_address=lighthouse.address(),
+                     event_injector=event_injector,
+                     train_loop=diloco_train_loop,
+                     train_loop_args={
+                         "model_state_dict": m.state_dict(),
+                     },
+                 )
+                 futures.append(executor.submit(runner.run_replica))
+
+             state_dicts = []
+
+             for fut in as_completed(futures):
+                 continue
+
+             for fut in futures:
+                 try:
+                     state_dicts.append(fut.result()[0])
+                 except Exception as e:
+                     print(e)
+                     raise
+
+         lighthouse.shutdown()
+
+         rep0, rep1 = state_dicts
+
+         # Inner optimizer and local model parameters will be different, e.g.
+         # with 2 replicas r1 and r2, we sync every 2 steps
+         #
+         # - Manager Step 1
+         #   - Step 1: r1 and r2 step
+         #   - Step 2: r1 and r2 step, sync the model, quorum succeeds
+         # - Manager Step 2
+         #   - Step 1: r1 steps but r2 fails
+         #   - Step 2:
+         #     - r1 steps, sync fails because r2 is down
+         #     - r1 recovers r2 from the model state at this step
+         #       that is different from the model for r1 at the beginning
+         #       of Manager Step 2
+         #
+         # Outer optimizer and global model should be the same
+         assert_equal_global_state(1, rep1, rep0)
+
+         self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
+
+     # pyre-fixme[56]: Pyre was not able to infer the type of argument
+     @skipIf(sys.platform == "darwin", "not reliable on mac")
+     @parameterized.expand(
+         [
+             # (True,),
+             (False,),
+         ]
+     )
+     def test_streaming_diloco_recovery(self, use_cuda: bool) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 2
+         futures = []
+
+         event_injectors = [
+             EventInjector(),
+             EventInjector().fail_at(0, 2),
+         ]
+
+         torch.manual_seed(42)
+         # Initialize the model so we can pass in the state_dict
+         m: nn.Module = MultiMyModel(2, 3, 2)
+
+         with ThreadPoolExecutor(max_workers=num_replicas) as executor:
+             for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+                 runner = Runner(
+                     replica_id=replica_id,
+                     num_replicas=num_replicas,
+                     lighthouse_address=lighthouse.address(),
+                     event_injector=event_injector,
+                     train_loop=diloco_train_loop,
+                     train_loop_args={
+                         "model_state_dict": m.state_dict(),
+                         "n_fragments": 2,
+                         "diloco_args": {
+                             "fragment_sync_delay": 1,
+                             "sync_every": 4,
+                         },
+                     },
+                 )
+                 futures.append(executor.submit(runner.run_replica))
+
+             state_dicts = []
+
+             for fut in as_completed(futures):
+                 continue
+
+             for fut in futures:
+                 try:
+                     state_dicts.append(fut.result()[0])
+                 except Exception as e:
+                     print(e)
+                     raise
+
+         lighthouse.shutdown()
+
+         rep0, rep1 = state_dicts
+
+         assert_equal_global_state(2, rep1, rep0)
+
+         self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
+
+     CONFIG: list[tuple[bool, int, int, float]] = [
+         (use_cuda, n_fragments, fragment_sync_delay, alpha)
+         for use_cuda in [False]
+         for n_fragments in [1, 2]
+         for fragment_sync_delay in [0, 1]
+         for alpha in [0.0, 0.5, 1.0]
+     ]
+
+     # pyre-fixme[56]: Pyre was not able to infer the type of argument
+     @skipIf(sys.platform == "darwin", "not reliable on mac")
+     @parameterized.expand(CONFIG)
+     def test_streaming_diloco_upscale(
+         self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
+     ) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 3
+         futures = []
+         executors = []
+
+         barrier = threading.Barrier(num_replicas)
+
+         event_injectors = [
+             # Make this replica join after other replicas have made 2 steps
+             EventInjector().barrier_at(0, 0, barrier),
+             EventInjector().barrier_at(0, 2, barrier),
+             EventInjector().barrier_at(0, 2, barrier),
+         ]
+
+         torch.manual_seed(42)
+         # Initialize the model so we can pass in the state_dict
+         m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+         for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+             executor = ThreadPoolExecutor(max_workers=1)
+             executors.append(executor)
+             runner = Runner(
+                 replica_id=replica_id,
+                 num_replicas=num_replicas,
+                 lighthouse_address=lighthouse.address(),
+                 event_injector=event_injector,
+                 train_loop=diloco_train_loop,
+                 train_loop_args={
+                     "model_state_dict": m.state_dict(),
+                     "n_fragments": n_fragments,
+                     "diloco_args": {
+                         "fragment_sync_delay": fragment_sync_delay,
+                         "sync_every": 4,
+                         "fragment_update_alpha": alpha,
+                     },
+                 },
+             )
+             futures.append(executor.submit(runner.run_replica))
+
+         state_dicts = []
+
+         for fut in as_completed(futures):
+             continue
+
+         for fut in futures:
+             try:
+                 state_dicts.append(fut.result()[0])
+             except Exception as e:
+                 print(e)
+                 raise
+
+         lighthouse.shutdown()
+
+         rep0, rep1, rep2 = state_dicts
+
+         assert_equal_global_state(n_fragments, rep0, rep1)
+         assert_equal_global_state(n_fragments, rep0, rep2)
+
+         for event_injector in event_injectors:
+             self.assertEqual(event_injector.count[EventInjectorEvent.Barrier], 1)
+
+     # pyre-fixme[56]: Pyre was not able to infer the type of argument
+     @skipIf(sys.platform == "darwin", "not reliable on mac")
+     @parameterized.expand(CONFIG)
+     def test_streaming_diloco_commit_failure(
+         self, use_cuda: bool, n_fragments: int, fragment_sync_delay: int, alpha: float
+     ) -> None:
+         # Skip the test if use_cuda is True and there are not enough GPUs
+         if use_cuda and torch.cuda.device_count() < 2:
+             self.skipTest("Not enough GPUs for CUDA test")
+         if sys.platform == "darwin":
+             self.skipTest("not reliable on mac")
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=2,
+         )
+         num_replicas = 2
+         futures = []
+         executors = []
+
+         event_injectors = [
+             EventInjector().fail_allreduce_at(0, 1),
+             EventInjector().fail_allreduce_at(0, 1),
+         ]
+
+         torch.manual_seed(42)
+         # Initialize the model so we can pass in the state_dict
+         m: nn.Module = MultiMyModel(2, 3, n_fragments)
+
+         for replica_id, event_injector in zip(range(num_replicas), event_injectors):
+             executor = ThreadPoolExecutor(max_workers=1)
+             executors.append(executor)
+             runner = Runner(
+                 replica_id=replica_id,
+                 num_replicas=num_replicas,
+                 lighthouse_address=lighthouse.address(),
+                 event_injector=event_injector,
+                 train_loop=diloco_train_loop,
+                 train_loop_args={
+                     "model_state_dict": m.state_dict(),
+                     "n_fragments": n_fragments,
+                     "diloco_args": {
+                         "fragment_sync_delay": fragment_sync_delay,
+                         "sync_every": 4,
+                         "fragment_update_alpha": alpha,
+                     },
+                 },
+             )
+             futures.append(executor.submit(runner.run_replica))
+
+         state_dicts = []
+
+         for fut in as_completed(futures):
+             continue
+
+         for fut in futures:
+             try:
+                 state_dicts.append(fut.result()[0])
+             except Exception as e:
+                 print(e)
+                 raise
+
+         lighthouse.shutdown()
+
+         rep0, rep1 = state_dicts
+
+         assert_equal_global_state(n_fragments, rep0, rep1)
+
+         for event_injector in event_injectors:
+             self.assertEqual(
+                 event_injector.count[EventInjectorEvent.AllreduceFailure], 1
+             )
+
+
+ if __name__ == "__main__":
+     import unittest
+
+     unittest.main()