torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager_integ_test.py
@@ -0,0 +1,653 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import copy
8
+ import logging
9
+ import threading
10
+ import time
11
+ import traceback
12
+ from collections import defaultdict
13
+ from concurrent.futures import as_completed, ThreadPoolExecutor
14
+ from contextlib import contextmanager, ExitStack
15
+ from dataclasses import dataclass, field
16
+ from datetime import timedelta
17
+ from enum import auto, Enum
18
+ from typing import (
19
+ Any,
20
+ cast,
21
+ Dict,
22
+ Generator,
23
+ List,
24
+ Optional,
25
+ Protocol,
26
+ Set,
27
+ Tuple,
28
+ TypeVar,
29
+ )
30
+ from unittest import TestCase
31
+
32
+ import torch
33
+ import torch.distributed as dist
34
+ from parameterized import parameterized
35
+ from torch import nn, optim
36
+ from torch._dynamo.utils import timed
37
+
38
+ from torchft._torchft import LighthouseServer
39
+ from torchft.ddp import DistributedDataParallel
40
+ from torchft.local_sgd import DiLoCo, LocalSGD
41
+ from torchft.manager import Manager
42
+ from torchft.optim import OptimizerWrapper
43
+ from torchft.process_group import (
44
+ FakeProcessGroupWrapper,
45
+ ProcessGroupBabyNCCL,
46
+ ProcessGroupGloo,
47
+ )
48
+
49
+ logging.basicConfig(level=logging.INFO)
50
+ logger: logging.Logger = logging.getLogger(__name__)
51
+
52
+ INIT_LOCK: threading.Lock = threading.Lock()
53
+
54
+
55
+ class MyModel(nn.Module):
56
+ def __init__(self, in_dim: int = 3, out_dim: int = 4) -> None:
57
+ super().__init__()
58
+ self.in_dim = in_dim
59
+ self.out_dim = out_dim
60
+ self.layers = nn.ModuleList(
61
+ [
62
+ nn.Linear(in_dim, 8),
63
+ nn.ReLU(),
64
+ nn.Linear(8, out_dim),
65
+ nn.ReLU(),
66
+ ]
67
+ )
68
+
69
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
70
+ for layer in self.layers:
71
+ x = layer(x)
72
+ return x
73
+
74
+ def get_rand_inputs(
75
+ self, batch_size: int, device: torch.device = torch.device("cpu")
76
+ ) -> torch.Tensor:
77
+ return torch.rand(batch_size, self.in_dim, device=device)
78
+
79
+ def get_rand_labels(
80
+ self, batch_size: int, device: torch.device = torch.device("cpu")
81
+ ) -> torch.Tensor:
82
+ return torch.randint(3, (batch_size,), device=device)
83
+
84
+
85
+ class InjectedFailure(Exception):
86
+ pass
87
+
88
+
89
+ class EventInjectorEvent(Enum):
90
+ # Crashes a rank
91
+ Failure = auto()
92
+ # Used to wait for a rank to reach a certain step before continuing.
93
+ # Users need to make sure the size of the barrier is appropriately set.
94
+ Barrier = auto()
95
+ # Fails the allreduce call made by a rank
96
+ AllreduceFailure = auto()
97
+
98
+
99
+ class EventInjectorInfo:
100
+ def __init__(self, event: EventInjectorEvent, data: object) -> None:
101
+ self.event = event
102
+ self.data = data
103
+
104
+
105
+ class EventInjector:
106
+ def __init__(self) -> None:
107
+ self._pg: Optional[FakeProcessGroupWrapper] = None
108
+ self._lock = threading.Lock()
109
+ self._events: Dict[Tuple[int, int], EventInjectorInfo] = {}
110
+ self.count: dict[EventInjectorEvent, int] = defaultdict(int)
111
+
112
+ def set_pg(self, pg: FakeProcessGroupWrapper) -> None:
113
+ with self._lock:
114
+ self._pg = pg
115
+
116
+ def fail_at(self, rank: int, step: int) -> "EventInjector":
117
+ with self._lock:
118
+ assert (rank, step) not in self._events
119
+ self._events[(rank, step)] = EventInjectorInfo(
120
+ EventInjectorEvent.Failure, None
121
+ )
122
+ return self
123
+
124
+ def fail_allreduce_at(self, rank: int, step: int) -> "EventInjector":
125
+ with self._lock:
126
+ assert (rank, step) not in self._events
127
+ self._events[(rank, step)] = EventInjectorInfo(
128
+ EventInjectorEvent.AllreduceFailure, None
129
+ )
130
+ return self
131
+
132
+ def barrier_at(
133
+ self, rank: int, step: int, barrier: threading.Barrier
134
+ ) -> "EventInjector":
135
+ with self._lock:
136
+ assert (rank, step) not in self._events
137
+ self._events[(rank, step)] = EventInjectorInfo(
138
+ EventInjectorEvent.Barrier, barrier
139
+ )
140
+ return self
141
+
142
+ def check(self, rank: int, step: int) -> None:
143
+ with self._lock:
144
+ key = (rank, step)
145
+ if key in self._events:
146
+ event_info = self._events.pop(key)
147
+
148
+ self.count[event_info.event] += 1
149
+
150
+ if event_info.event == EventInjectorEvent.Failure:
151
+ print(f"injecting failure {rank=} {step=}")
152
+ raise InjectedFailure(f"injected failure {rank=} {step=}")
153
+
154
+ if event_info.event == EventInjectorEvent.AllreduceFailure:
155
+ print(f"injecting allreduce failure {rank=} {step=}")
156
+ assert self._pg is not None
157
+ self._pg.report_future_error(
158
+ RuntimeError("injected allreduce error")
159
+ )
160
+ return
161
+
162
+ if event_info.event == EventInjectorEvent.Barrier:
163
+ print(f"waiting for barrier {rank=} {step=}")
164
+ cast(threading.Barrier, event_info.data).wait()
165
+ return
166
+
167
+ raise RuntimeError(f"unknown event {event_info.event}")
168
+
169
+
170
+ # R for an arbitrary return type
171
+ R = TypeVar("R", covariant=True)
172
+
173
+
174
+ class TrainLoop(Protocol[R]):
175
+ def __call__(
176
+ self,
177
+ rank: int,
178
+ store_port: int,
179
+ device: torch.device,
180
+ runner: "Runner",
181
+ train_loop_args: dict[str, Any] = field(default_factory=dict),
182
+ ) -> R: ...
183
+
184
+
185
+ @dataclass
186
+ class Runner:
187
+ replica_id: int
188
+ num_replicas: int
189
+ lighthouse_address: str
190
+ event_injector: EventInjector
191
+ train_loop: TrainLoop[object]
192
+
193
+ use_cuda: bool = False
194
+ world_size: int = 1
195
+ attempts: int = 3
196
+ manager_args: Dict[str, object] = field(default_factory=dict)
197
+ train_loop_args: Dict[str, Any] = field(default_factory=dict)
198
+
199
+ def _replica_main(self) -> List[object]:
200
+ store = dist.TCPStore(
201
+ host_name="localhost",
202
+ port=0,
203
+ is_master=True,
204
+ wait_for_workers=False,
205
+ )
206
+
207
+ with ThreadPoolExecutor(
208
+ max_workers=self.world_size, thread_name_prefix=f"replica{self.replica_id}"
209
+ ) as executor:
210
+ futures = []
211
+ for rank in range(self.world_size):
212
+ if self.use_cuda:
213
+ num_cuda_devices = torch.cuda.device_count()
214
+ assert num_cuda_devices >= self.num_replicas
215
+ device_index = (
216
+ num_cuda_devices // self.num_replicas
217
+ ) * self.replica_id + rank
218
+ device = torch.device(f"cuda:{device_index}")
219
+ else:
220
+ device = torch.device("cpu")
221
+
222
+ futures.append(
223
+ executor.submit(
224
+ self.train_loop,
225
+ rank=rank,
226
+ store_port=store.port,
227
+ device=device,
228
+ runner=self,
229
+ train_loop_args=self.train_loop_args,
230
+ )
231
+ )
232
+
233
+ for fut in as_completed(futures):
234
+ try:
235
+ fut.result()
236
+ except Exception as e:
237
+ logger.exception(f"worker {self.replica_id=} threw exception: {e}")
238
+ raise
239
+
240
+ return [fut.result() for fut in futures]
241
+
242
+ def run_replica(self) -> List[object]:
243
+ for i in range(self.attempts):
244
+ try:
245
+ print(
246
+ f"starting replica group {self.replica_id=} {self.world_size=} attempt {i}"
247
+ )
248
+ return self._replica_main()
249
+ except InjectedFailure as e:
250
+ print("got injected failure", i, e)
251
+ if i == self.attempts - 1:
252
+ raise
253
+ continue
254
+
255
+ raise RuntimeError("ran out of attempts")
256
+
257
+
258
+ def ddp_train_loop(
259
+ rank: int,
260
+ store_port: int,
261
+ device: torch.device,
262
+ runner: Runner,
263
+ train_loop_args: dict[str, Any] = {},
264
+ ) -> Dict[str, Dict[str, object]]:
265
+ with ExitStack() as stack:
266
+
267
+ def load_state_dict(state_dict: Dict[str, Dict[str, object]]) -> None:
268
+ m.load_state_dict(state_dict["model"])
269
+ optimizer.load_state_dict(state_dict["optim"])
270
+
271
+ def state_dict() -> Dict[str, Dict[str, object]]:
272
+ return {
273
+ "model": m.state_dict(),
274
+ "optim": optimizer.state_dict(),
275
+ }
276
+
277
+ print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
278
+
279
+ pg = ProcessGroupGloo()
280
+ manager = Manager(
281
+ pg=pg,
282
+ min_replica_size=2,
283
+ load_state_dict=load_state_dict,
284
+ state_dict=state_dict,
285
+ replica_id=str(runner.replica_id),
286
+ store_addr="localhost",
287
+ store_port=store_port,
288
+ rank=rank,
289
+ world_size=runner.world_size,
290
+ lighthouse_addr=runner.lighthouse_address,
291
+ port=19530 + runner.replica_id,
292
+ # pyre-fixme[6]: Incompatible parameter type
293
+ **runner.manager_args,
294
+ )
295
+ stack.callback(lambda: manager.shutdown(wait=False))
296
+
297
+ with INIT_LOCK:
298
+ # We need to lock during init for testing init_sync=False as all
299
+ # threads share the same RNG
300
+ torch.manual_seed(42)
301
+ m: nn.Module = MyModel()
302
+
303
+ m: nn.Module = DistributedDataParallel(manager, m)
304
+ optimizer: optim.Optimizer = OptimizerWrapper(
305
+ manager, optim.Adam(m.parameters())
306
+ )
307
+ criterion = nn.CrossEntropyLoss()
308
+
309
+ while True:
310
+ inputs = torch.rand(2, 3)
311
+ labels = torch.randint(4, (2,))
312
+
313
+ optimizer.zero_grad()
314
+ out = m(inputs)
315
+ loss = criterion(out, labels)
316
+
317
+ loss.backward()
318
+
319
+ optimizer.step()
320
+
321
+ if manager.current_step() >= 4:
322
+ break
323
+
324
+ runner.event_injector.check(rank, manager.current_step())
325
+
326
+ # return state_dict so we can check consistency
327
+ return state_dict()
328
+
329
+
330
+ class ManagerIntegTest(TestCase):
331
+ @contextmanager
332
+ def assertElapsedLessThan(
333
+ self, timeout: float, msg: str = ""
334
+ ) -> Generator[None, None, None]:
335
+ start = time.perf_counter()
336
+ yield
337
+ elapsed = time.perf_counter() - start
338
+ self.assertLess(elapsed, timeout, msg)
339
+
340
+ def test_ddp_healthy(self) -> None:
341
+ lighthouse = LighthouseServer(
342
+ bind="[::]:0",
343
+ min_replicas=2,
344
+ )
345
+ num_replicas = 2
346
+ futures = []
347
+
348
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
349
+ for replica_id in range(num_replicas):
350
+ event_injector = EventInjector()
351
+ runner = Runner(
352
+ replica_id=replica_id,
353
+ num_replicas=num_replicas,
354
+ lighthouse_address=lighthouse.address(),
355
+ event_injector=event_injector,
356
+ train_loop=ddp_train_loop,
357
+ )
358
+ futures.append(executor.submit(runner.run_replica))
359
+
360
+ state_dicts = []
361
+
362
+ for fut in as_completed(futures):
363
+ state_dicts.append(fut.result())
364
+
365
+ lighthouse.shutdown()
366
+
367
+ for state_dict in state_dicts:
368
+ torch.testing.assert_close(state_dict, state_dicts[0])
369
+
370
+ @parameterized.expand(
371
+ [
372
+ (
373
+ "async_quorum",
374
+ True,
375
+ ),
376
+ (
377
+ "sync_quorum",
378
+ False,
379
+ ),
380
+ ]
381
+ )
382
+ def test_ddp_recovery(
383
+ self,
384
+ name: str,
385
+ use_async_quorum: bool,
386
+ ) -> None:
387
+ lighthouse = LighthouseServer(
388
+ bind="[::]:0",
389
+ min_replicas=2,
390
+ )
391
+ num_replicas = 2
392
+ futures = []
393
+
394
+ event_injectors = [
395
+ EventInjector(),
396
+ EventInjector().fail_at(0, 2),
397
+ ]
398
+
399
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
400
+ for replica_id, event_injector in zip(range(num_replicas), event_injectors):
401
+ runner = Runner(
402
+ replica_id=replica_id,
403
+ num_replicas=num_replicas,
404
+ lighthouse_address=lighthouse.address(),
405
+ event_injector=event_injector,
406
+ manager_args={
407
+ "use_async_quorum": use_async_quorum,
408
+ },
409
+ train_loop=ddp_train_loop,
410
+ )
411
+ futures.append(executor.submit(runner.run_replica))
412
+
413
+ state_dicts = []
414
+
415
+ for fut in as_completed(futures):
416
+ try:
417
+ state_dicts.append(fut.result())
418
+ except Exception as e:
419
+ print(e)
420
+ raise
421
+
422
+ lighthouse.shutdown()
423
+
424
+ for state_dict in state_dicts:
425
+ torch.testing.assert_close(state_dict, state_dicts[0])
426
+
427
+ self.assertEqual(event_injectors[1].count[EventInjectorEvent.Failure], 1)
428
+
429
+ def test_ddp_skip_init_sync(
430
+ self,
431
+ ) -> None:
432
+ lighthouse = LighthouseServer(
433
+ bind="[::]:0",
434
+ min_replicas=2,
435
+ )
436
+ num_replicas = 2
437
+ futures = []
438
+
439
+ # no failures
440
+ event_injectors = [
441
+ EventInjector(),
442
+ EventInjector(),
443
+ ]
444
+
445
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
446
+ for replica_id, event_injector in zip(range(num_replicas), event_injectors):
447
+ runner = Runner(
448
+ replica_id=replica_id,
449
+ num_replicas=num_replicas,
450
+ lighthouse_address=lighthouse.address(),
451
+ event_injector=event_injector,
452
+ manager_args={
453
+ "use_async_quorum": False,
454
+ "init_sync": False,
455
+ },
456
+ train_loop=ddp_train_loop,
457
+ )
458
+ futures.append(executor.submit(runner.run_replica))
459
+
460
+ state_dicts = []
461
+
462
+ for fut in as_completed(futures):
463
+ try:
464
+ state_dicts.append(fut.result())
465
+ except Exception as e:
466
+ print(e)
467
+ raise
468
+
469
+ lighthouse.shutdown()
470
+
471
+ for state_dict in state_dicts:
472
+ torch.testing.assert_close(state_dict, state_dicts[0])
473
+
474
+ def test_ddp_recovery_multi_rank(self) -> None:
475
+ lighthouse = LighthouseServer(
476
+ bind="[::]:0",
477
+ min_replicas=2,
478
+ )
479
+ num_replicas = 2
480
+ world_size = 2
481
+ futures = []
482
+
483
+ event_injectors = [
484
+ EventInjector(),
485
+ EventInjector().fail_at(0, 2).fail_at(1, 2),
486
+ ]
487
+
488
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
489
+ for replica_id, event_injector in zip(range(num_replicas), event_injectors):
490
+ runner = Runner(
491
+ replica_id=replica_id,
492
+ num_replicas=num_replicas,
493
+ lighthouse_address=lighthouse.address(),
494
+ event_injector=event_injector,
495
+ world_size=world_size,
496
+ train_loop=ddp_train_loop,
497
+ )
498
+ futures.append(executor.submit(runner.run_replica))
499
+
500
+ state_dicts = []
501
+
502
+ for fut in as_completed(futures):
503
+ try:
504
+ state_dicts.append(fut.result())
505
+ except Exception as e:
506
+ print(e)
507
+ raise
508
+
509
+ lighthouse.shutdown()
510
+
511
+ for state_dict in state_dicts:
512
+ torch.testing.assert_close(state_dict, state_dicts[0])
513
+
514
+ def test_quorum_timeout(self) -> None:
515
+ with ExitStack() as stack:
516
+ lighthouse = LighthouseServer(
517
+ bind="[::]:0",
518
+ min_replicas=2,
519
+ )
520
+ stack.callback(lighthouse.shutdown)
521
+
522
+ store = dist.TCPStore(
523
+ host_name="localhost",
524
+ port=0,
525
+ is_master=True,
526
+ wait_for_workers=False,
527
+ )
528
+
529
+ pg = ProcessGroupGloo()
530
+ manager = Manager(
531
+ pg=pg,
532
+ min_replica_size=2,
533
+ load_state_dict=lambda x: None,
534
+ state_dict=lambda: None,
535
+ store_addr="localhost",
536
+ store_port=store.port,
537
+ rank=0,
538
+ world_size=2,
539
+ lighthouse_addr=lighthouse.address(),
540
+ port=19530,
541
+ use_async_quorum=False,
542
+ )
543
+ stack.callback(lambda: manager.shutdown(wait=False))
544
+
545
+ with self.assertElapsedLessThan(1.0):
546
+ with self.assertRaisesRegex(
547
+ TimeoutError,
548
+ "status: Cancelled, message.*Timeout expired",
549
+ ):
550
+ manager.start_quorum(timeout=timedelta(seconds=0.01))
551
+
552
+ with self.assertElapsedLessThan(1.0):
553
+ with self.assertRaisesRegex(
554
+ TimeoutError,
555
+ "status: Cancelled, message.*Timeout expired",
556
+ ):
557
+ manager.should_commit(timeout=timedelta(seconds=0.01))
558
+
559
+ @parameterized.expand(
560
+ [
561
+ (True,), # Test with CUDA
562
+ (False,), # Test without CUDA (CPU)
563
+ ]
564
+ )
565
+ def test_manager_allreduce(self, use_cuda: bool) -> None:
566
+ # Skip the test if use_cuda is True and there are not enough GPUs
567
+ if use_cuda and torch.cuda.device_count() < 2:
568
+ self.skipTest("Not enough GPUs for CUDA test")
569
+
570
+ # manager supports allreduce but we found an issue where the future callback is getting called
571
+ # before the allreduce is complete. This test is to ensure that the callback has stream synchronization
572
+ lighthouse = LighthouseServer(
573
+ bind="[::]:0",
574
+ min_replicas=2,
575
+ )
576
+ num_replicas = 2
577
+ futures = []
578
+
579
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
580
+ for replica_id in range(num_replicas):
581
+ event_injector = EventInjector()
582
+ runner = Runner(
583
+ replica_id=replica_id,
584
+ num_replicas=num_replicas,
585
+ lighthouse_address=lighthouse.address(),
586
+ event_injector=event_injector,
587
+ train_loop=all_reduce_callback,
588
+ use_cuda=use_cuda,
589
+ )
590
+ futures.append(executor.submit(runner.run_replica))
591
+
592
+ results = []
593
+ for fut in as_completed(futures):
594
+ try:
595
+ results.append(fut.result()[0])
596
+ except Exception as e:
597
+ print(e, flush=True)
598
+ traceback.print_exc()
599
+ raise
600
+
601
+ lighthouse.shutdown()
602
+
603
+ print(results)
604
+ r0, r1 = results
605
+ torch.testing.assert_close(r0, r1, check_device=False)
606
+
607
+
608
+ def all_reduce_callback(
609
+ rank: int,
610
+ store_port: int,
611
+ device: torch.device,
612
+ runner: Runner,
613
+ train_loop_args: dict[str, Any] = {},
614
+ ) -> Optional[torch.Tensor]:
615
+ with ExitStack() as stack:
616
+ print(f"worker {runner.replica_id=} {rank=} {runner.world_size=} starting")
617
+
618
+ if device.type == "cuda":
619
+ pg = ProcessGroupBabyNCCL()
620
+ else:
621
+ pg = ProcessGroupGloo()
622
+ manager = Manager(
623
+ pg=pg,
624
+ min_replica_size=2,
625
+ use_async_quorum=False,
626
+ load_state_dict=lambda x: None,
627
+ state_dict=lambda: None,
628
+ replica_id=str(runner.replica_id),
629
+ store_addr="localhost",
630
+ store_port=store_port,
631
+ rank=rank,
632
+ world_size=runner.world_size,
633
+ lighthouse_addr=runner.lighthouse_address,
634
+ port=19530 + runner.replica_id,
635
+ timeout=timedelta(seconds=10),
636
+ quorum_timeout=timedelta(seconds=10),
637
+ # pyre-fixme[6]: Incompatible parameter type
638
+ **runner.manager_args,
639
+ )
640
+ stack.callback(lambda: manager.shutdown(wait=False))
641
+
642
+ manager.start_quorum()
643
+ t1 = torch.ones((1, 3), device=device)
644
+ work = manager.allreduce(t1)
645
+ work.wait()
646
+ return t1
647
+ return None
648
+
649
+
650
+ if __name__ == "__main__":
651
+ import unittest
652
+
653
+ unittest.main()