torchft-nightly 2026.1.3 (cp310-cp310-manylinux_2_24_x86_64.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/futures.py ADDED
@@ -0,0 +1,353 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import asyncio
+ import os
+ import queue
+ import sys
+ import threading
+ import time
+ from contextlib import contextmanager, nullcontext
+ from datetime import timedelta
+ from typing import Callable, Generator, Optional, TypeVar
+ from unittest.mock import Mock
+
+ import torch
+ from torch.futures import Future
+
+ from torchft.utils import get_stream_context
+
+ T = TypeVar("T")
+
+ WATCHDOG_TIMEOUT_SEC = "TORCHFT_WATCHDOG_TIMEOUT_SEC"
+
+
+ class _TimerHandle:
+     def __init__(self) -> None:
+         self._lock = threading.Lock()
+         self._timer_handle: Optional[asyncio.TimerHandle] = None
+         self._cancelled = False
+
+     def set_timer_handle(self, timer_handle: asyncio.TimerHandle) -> None:
+         with self._lock:
+             if self._cancelled:
+                 timer_handle.cancel()
+                 self._timer_handle = None
+             else:
+                 self._timer_handle = timer_handle
+
+     def cancel(self) -> None:
+         with self._lock:
+             assert not self._cancelled, "timer can only be cancelled once"
+             self._cancelled = True
+             if self._timer_handle is not None:
+                 self._timer_handle.cancel()
+                 self._timer_handle = None
+
+
+ class _TimeoutManager:
+     """
+     This class manages timeouts for code blocks, futures and CUDA events. It
+     uses a background thread with an event loop to schedule the timeouts and
+     call the callback function when the timeout is reached.
+
+     Generally there is a single instance of this class that is used for all
+     timeouts. The callbacks should not block otherwise other timeouts may not
+     be processed.
+     """
+
+     def __init__(self) -> None:
+         self._lock = threading.Lock()
+         self._event_loop: Optional[asyncio.AbstractEventLoop] = None
+         self._event_loop_thread: Optional[threading.Thread] = None
+         self._next_timer_id = 0
+
+         # Ensures `_event_loop_thread` is not stuck
+         self._watchdog_thread: Optional[threading.Thread] = None
+
+         # Give the `_event_loop_thread` this much time to confirm that
+         # it is not stuck
+         self._watchdog_interval = timedelta(
+             seconds=int(os.environ.get(WATCHDOG_TIMEOUT_SEC, "30"))
+         )
+
+         # This queue is used to delete events on the main thread as cudaEventDestroy
+         # can block if the CUDA queue is full.
+         self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()
+
+     def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
+         """
+         Start the event loop if it has not already been started.
+         """
+         with self._lock:
+             if self._event_loop is None:
+                 self._event_loop = asyncio.new_event_loop()
+                 self._event_loop_thread = threading.Thread(
+                     target=self._event_loop.run_forever,
+                     daemon=True,
+                     name="TimeoutManager",
+                 )
+                 self._event_loop_thread.start()
+
+                 self._watchdog_thread = threading.Thread(
+                     target=self._watchdog_loop, daemon=True
+                 )
+                 self._watchdog_thread.start()
+
+             # pyre-fixme[7]: optional
+             return self._event_loop
+
+     def _watchdog_loop(self) -> None:
+         while True:
+             is_healthy = False
+
+             def updated_health() -> None:
+                 nonlocal is_healthy
+                 is_healthy = True
+
+             with self._lock:
+                 if self._event_loop is None:
+                     return
+
+                 # The method passed to the event loop should finish fast.
+                 # It just updates a bool, which is also thread safe.
+                 self._event_loop.call_soon_threadsafe(updated_health)
+
+             time.sleep(self._watchdog_interval.total_seconds())
+
+             if not is_healthy:
+                 print("TimeoutManager is stuck. Exiting.")
+                 sys.exit(1)
+                 # Needed because `sys.exit` is mocked in unit tests.
+                 # If we don't return here, we don't break out of the loop.
+                 return
+
+     def shutdown(self) -> None:
+         """
+         Shutdown the event loop and cancel all pending timeouts.
+         """
+         watchdog_thread = None
+         with self._lock:
+             if self._event_loop is not None:
+                 self._event_loop.call_soon_threadsafe(self._event_loop.stop)
+                 assert self._event_loop_thread is not None
+                 self._event_loop_thread.join()
+                 self._event_loop = None
+                 self._event_loop_thread = None
+
+             # We can't join the watchdog thread here because it grabs `_lock`
+             watchdog_thread = self._watchdog_thread
+
+         if watchdog_thread is not None:
+             # If `_maybe_start_event_loop` is called again, it is possible the `join`
+             # below will never finish.
+             # This class assumes `_maybe_start_event_loop` will not be called after `shutdown`.
+             # If this functionality is required in the future, we could change the class to
+             # support this. Or create multiple instances of this class.
+             watchdog_thread.join()
+
+     def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
+         """
+         Registers a future that will be cancelled after the specified timeout.
+         """
+         # bypass timeout for mock futures
+         if isinstance(fut, Mock):
+             return fut
+
+         self._clear_del_queue()
+
+         loop = self._maybe_start_event_loop()
+
+         timed_fut: Future[T] = Future()
+         handle: _TimerHandle = _TimerHandle()
+         loop.call_soon_threadsafe(
+             self._register_callback,
+             loop,
+             lambda: timed_fut.set_exception(
+                 # pyre-fixme[6]: e is not T
+                 TimeoutError(f"future did not complete within {timeout}")
+             ),
+             timeout,
+             handle,
+         )
+
+         stream: Optional[torch.Stream] = (
+             torch.accelerator.current_stream()
+             if torch.accelerator.is_available()
+             else None
+         )
+
+         def callback(fut: Future[T]) -> None:
+             with get_stream_context(stream):
+                 handle.cancel()
+                 try:
+                     timed_fut.set_result(fut.wait())
+                 except Exception as e:
+                     try:
+                         # this can throw if the future is already done
+                         # pyre-fixme[6]: e is not T
+                         timed_fut.set_exception(e)
+                     except Exception:
+                         pass
+
+         fut.add_done_callback(callback)
+         return timed_fut
+
+     def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
+         self._clear_del_queue()
+
+         loop = self._maybe_start_event_loop()
+
+         event: torch.Event = torch.Event()
+         event.record()
+
+         def handler() -> None:
+             if not event.query():
+                 callback()
+
+             # cudaEventDestroy can block so we never want to delete in the event
+             # loop. Put it on the del queue so we can delete it in the main
+             # thread.
+             self._del_queue.put(event)
+
+         loop.call_soon_threadsafe(
+             self._register_callback, loop, handler, timeout, _TimerHandle()
+         )
+
+     @classmethod
+     def _register_callback(
+         cls,
+         loop: asyncio.AbstractEventLoop,
+         callback: Callable[[], None],
+         timeout: timedelta,
+         handle: _TimerHandle,
+     ) -> None:
+         timer_handle = loop.call_later(
+             timeout.total_seconds(),
+             callback,
+         )
+         handle.set_timer_handle(timer_handle)
+
+     @contextmanager
+     def context_timeout(
+         self, callback: Callable[[], None], timeout: timedelta
+     ) -> Generator[None, None, None]:
+         self._clear_del_queue()
+
+         loop = self._maybe_start_event_loop()
+         handle = _TimerHandle()
+
+         loop.call_soon_threadsafe(
+             self._register_callback, loop, callback, timeout, handle
+         )
+
+         yield
+
+         handle.cancel()
+
+     def _clear_del_queue(self) -> int:
+         """
+         Clear the queue of futures to be deleted.
+
+         Returns the number of items deleted.
+         """
+         count = 0
+         while True:
+             try:
+                 # get and immediately discard item
+                 item = self._del_queue.get_nowait()
+                 refcount = sys.getrefcount(item)
+                 assert (
+                     # 1 from item, 1 from getrefcount
+                     refcount == 2
+                 ), f"items in del_queue reference should not have other references, found {refcount=}"
+                 del item
+
+                 count += 1
+             except queue.Empty:
+                 break
+
+         return count
+
+
+ _TIMEOUT_MANAGER = _TimeoutManager()
+
+
+ def future_timeout(fut: Future[T], timeout: timedelta) -> Future[T]:
+     """
+     Return a Future that completes with the result of the given Future within
+     the given timeout or with a TimeoutError.
+
+     Args:
+         fut: The Future to wait for
+         timeout: The timeout to wait for the Future to complete
+
+     Returns:
+         The future with a timeout
+     """
+     return _TIMEOUT_MANAGER.register(fut, timeout)
+
+
+ def future_wait(fut: Future[T], timeout: timedelta) -> T:
+     """
+     Wait for a Future to complete up to a timeout.
+
+     Args:
+         fut: The Future to wait for
+         timeout: The timeout to wait for the Future to complete
+
+     Returns:
+         The result of the Future if it completed within the timeout.
+
+     Raises:
+         TimeoutError if the Future did not complete within the timeout.
+         Any other exception that occurred in the Future.
+     """
+
+     event: threading.Event = threading.Event()
+
+     def callback(fut: Future[T]) -> T:
+         event.set()
+         return fut.wait()
+
+     fut = fut.then(callback)
+
+     if not event.wait(timeout=timeout.total_seconds()):
+         raise TimeoutError(f"future did not complete within {timeout}")
+
+     return fut.wait()
+
+
+ def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
+     """
+     Registers a callback that will be called after the specified timeout if
+     the current stream doesn't complete in time.
+
+     This uses a cuda Event to track the completion of the current stream. If
+     the stream is not complete after the timeout, the callback is called.
+
+     Args:
+         callback: The callback to call if the stream doesn't complete in time.
+         timeout: The timeout to wait for the stream to complete.
+     """
+     _TIMEOUT_MANAGER.stream_timeout(callback, timeout)
+
+
+ @contextmanager
+ def context_timeout(
+     callback: Callable[[], None], timeout: timedelta
+ ) -> Generator[None, None, None]:
+     """
+     Registers a callback that will be called after the specified timeout if
+     the current contextmanager doesn't exit in time.
+
+     Args:
+         callback: The callback to call if we time out.
+         timeout: How long to wait for the contextmanager to exit.
+     """
+
+     with _TIMEOUT_MANAGER.context_timeout(callback, timeout):
+         yield
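
The four public helpers above all route through the shared _TIMEOUT_MANAGER. A minimal usage sketch on CPU (the `on_timeout` handler below is a hypothetical placeholder, not part of torchft):

    from datetime import timedelta

    from torch.futures import Future

    from torchft.futures import context_timeout, future_timeout, future_wait

    # Wrap an existing future so it fails with TimeoutError if unresolved in time.
    fut: Future[int] = Future()
    timed = future_timeout(fut, timeout=timedelta(seconds=5))
    fut.set_result(42)
    assert future_wait(timed, timeout=timedelta(seconds=1)) == 42

    # Guard a code block with a deadline; on_timeout fires only if the block
    # is still running when the timer expires.
    def on_timeout() -> None:
        print("block exceeded deadline")  # e.g. abort a hung collective

    with context_timeout(on_timeout, timedelta(seconds=5)):
        pass  # work that should finish within 5 seconds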
torchft/futures_test.py ADDED
@@ -0,0 +1,140 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import threading
+ from datetime import timedelta
+ from unittest import skipUnless, TestCase
+ from unittest.mock import Mock, patch
+
+ import torch
+ from torch.futures import Future
+
+ from torchft.futures import (
+     _TIMEOUT_MANAGER,
+     context_timeout,
+     future_timeout,
+     future_wait,
+     stream_timeout,
+ )
+
+
+ class FuturesTest(TestCase):
+     def setUp(self) -> None:
+         self._original_watchdog_interval = _TIMEOUT_MANAGER._watchdog_interval
+         _TIMEOUT_MANAGER._watchdog_interval = timedelta(seconds=1)
+
+     def tearDown(self) -> None:
+         _TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval
+
+     def test_future_wait(self) -> None:
+         fut = Future()
+         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
+             future_wait(fut, timeout=timedelta(seconds=0.01))
+
+         fut = Future()
+         fut.set_result(1)
+         self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)
+
+         fut = Future()
+         fut.set_exception(RuntimeError("test"))
+         with self.assertRaisesRegex(RuntimeError, "test"):
+             future_wait(fut, timeout=timedelta(seconds=1.0))
+
+     def test_future_timeout(self) -> None:
+         fut = Future()
+         timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
+         with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
+             timed_fut.wait()
+
+     def test_future_timeout_result(self) -> None:
+         fut = Future()
+         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
+         fut.set_result(1)
+         self.assertEqual(timed_fut.wait(), 1)
+
+     def test_future_timeout_exception(self) -> None:
+         fut = Future()
+         timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
+         fut.set_exception(RuntimeError("test"))
+         with self.assertRaisesRegex(RuntimeError, "test"):
+             timed_fut.wait()
+
+     def test_context_timeout(self) -> None:
+         barrier: threading.Barrier = threading.Barrier(2)
+
+         def callback() -> None:
+             barrier.wait()
+
+         with context_timeout(callback, timedelta(seconds=0.01)):
+             # block until timeout fires
+             barrier.wait()
+
+         def fail() -> None:
+             self.fail("timeout should be cancelled")
+
+         with context_timeout(fail, timedelta(seconds=10)):
+             pass
+
+     # pyre-fixme[56]: Pyre was not able to infer the type of decorator
+     @skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
+     def test_stream_timeout(self) -> None:
+         torch.cuda.synchronize()
+
+         def callback() -> None:
+             self.fail()
+
+         stream_timeout(callback, timeout=timedelta(seconds=0.01))
+
+         # make sure event completes
+         torch.cuda.synchronize()
+
+         # make sure that event is deleted on the deletion queue
+         item = _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
+         _TIMEOUT_MANAGER._del_queue.put(item)
+         del item
+
+         self.assertEqual(_TIMEOUT_MANAGER._clear_del_queue(), 1)
+
+     # Test that when a timeout handle gets stuck, `sys.exit(1)` is called
+     @patch("sys.exit")
+     def test_exit_on_stuck_callback(self, mock_exit: Mock) -> None:
+         exit_event: threading.Event = threading.Event()
+
+         def custom_exit(_) -> None:
+             # 3. When event loop is stuck, exit(1) is called
+             nonlocal exit_event
+             exit_event.set()
+
+         mock_exit.side_effect = custom_exit
+
+         callback_event: threading.Event = threading.Event()
+
+         def callback() -> None:
+             # 2. Make sure callback blocks event loop
+             nonlocal callback_event
+             callback_event.wait()
+
+         context_event: threading.Event = threading.Event()
+
+         def thread_fn() -> None:
+             with context_timeout(callback, timedelta(seconds=0.01)):
+                 # 1. Make sure context doesn't finish in time
+                 nonlocal context_event
+                 context_event.wait()
+
+         thread = threading.Thread(target=thread_fn)
+         thread.start()
+
+         # 4. exit event will wake this up
+         exit_event.wait()
+         mock_exit.assert_called_once_with(1)
+
+         # 5. event loop is still stuck, so let's unblock it
+         callback_event.set()
+
+         # 6. unblock the context and make sure it exits
+         context_event.set()
+         thread.join()
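
These tests shrink the watchdog interval by poking `_TIMEOUT_MANAGER` directly. Outside of tests, the interval comes from the TORCHFT_WATCHDOG_TIMEOUT_SEC environment variable read in futures.py; a minimal sketch (assuming the variable is set before the module is first imported, since the module-level manager reads it in its constructor):

    import os

    # Default is 30 seconds; must be set before `torchft.futures` is imported.
    os.environ["TORCHFT_WATCHDOG_TIMEOUT_SEC"] = "5"

    from torchft.futures import _TIMEOUT_MANAGER

    assert _TIMEOUT_MANAGER._watchdog_interval.total_seconds() == 5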
torchft/http.py ADDED
@@ -0,0 +1,13 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import socket
+ from http.server import ThreadingHTTPServer
+
+
+ class _IPv6HTTPServer(ThreadingHTTPServer):
+     address_family: socket.AddressFamily = socket.AF_INET6
+     request_queue_size: int = 1024
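
This subclass only pins the address family to IPv6 and widens the accept backlog. A sketch of how such a server is constructed (the EchoHandler is illustrative, not part of torchft):

    from http.server import BaseHTTPRequestHandler

    from torchft.http import _IPv6HTTPServer

    class EchoHandler(BaseHTTPRequestHandler):
        def do_GET(self) -> None:
            self.send_response(200)
            self.end_headers()
            self.wfile.write(b"ok")

    # "::" binds all IPv6 interfaces; port 0 lets the OS pick a free port.
    server = _IPv6HTTPServer(("::", 0), EchoHandler)
    print(f"listening on port {server.server_address[1]}")
    # server.serve_forever()  # blocks; run in a thread in real usage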
torchft/lighthouse_test.py ADDED
@@ -0,0 +1,163 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import time
+ from datetime import timedelta
+ from unittest import TestCase
+
+ import torch.distributed as dist
+
+ from torchft import Manager, ProcessGroupGloo
+ from torchft._torchft import LighthouseClient, LighthouseServer, Quorum, QuorumMember
+
+
+ class TestLighthouse(TestCase):
+     def test_join_timeout_behavior(self) -> None:
+         """Test that join_timeout_ms affects joining behavior"""
+         # To test, we create a lighthouse with 100ms and 400ms join timeouts
+         # and measure the time taken to validate the quorum.
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=1,
+             join_timeout_ms=100,
+         )
+
+         # Create a manager that tries to join
+         try:
+             store = dist.TCPStore(
+                 host_name="localhost",
+                 port=0,
+                 is_master=True,
+                 wait_for_workers=False,
+             )
+             pg = ProcessGroupGloo()
+             manager = Manager(
+                 pg=pg,
+                 min_replica_size=1,
+                 load_state_dict=lambda x: None,
+                 state_dict=lambda: None,
+                 replica_id="lighthouse_test",
+                 store_addr="localhost",
+                 store_port=store.port,
+                 rank=0,
+                 world_size=1,
+                 use_async_quorum=False,
+                 lighthouse_addr=lighthouse.address(),
+             )
+
+             start_time = time.time()
+             manager.start_quorum()
+             time_taken = time.time() - start_time
+             assert time_taken < 0.4, f"Time taken to join: {time_taken} > 0.4s"
+
+         finally:
+             # Cleanup
+             lighthouse.shutdown()
+             if "manager" in locals():
+                 manager.shutdown()
+
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=1,
+             join_timeout_ms=400,
+         )
+
+     def test_heartbeat_timeout_ms_sanity(self) -> None:
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=1,
+             heartbeat_timeout_ms=100,
+         )
+         lighthouse.shutdown()
+
+     def test_lighthouse_client_behavior(self) -> None:
+         """Test generic quorum behavior when using LighthouseClient directly"""
+         # To test, we create a lighthouse and issue quorum requests through a
+         # LighthouseClient, checking the returned members and attached data.
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=1,
+             join_timeout_ms=100,
+         )
+
+         # Create a client that requests a quorum
+         try:
+             client = LighthouseClient(
+                 addr=lighthouse.address(),
+                 connect_timeout=timedelta(seconds=1),
+             )
+             store = dist.TCPStore(
+                 host_name="localhost",
+                 port=0,
+                 is_master=True,
+                 wait_for_workers=False,
+             )
+             result = client.quorum(
+                 replica_id="lighthouse_test",
+                 address="localhost",
+                 store_address=f"localhost:{store.port}",
+                 step=1,
+                 world_size=1,
+                 shrink_only=False,
+                 timeout=timedelta(seconds=1),
+                 data={"my_data": 1234},
+             )
+             assert result is not None
+             assert isinstance(result, Quorum)
+             assert len(result.participants) == 1
+             for member in result.participants:
+                 assert isinstance(member, QuorumMember)
+                 assert member.replica_id == "lighthouse_test"
+                 assert member.data is not None
+                 assert "my_data" in member.data
+                 assert member.data["my_data"] == 1234
+
+             # Test the optional args
+             result = client.quorum(
+                 replica_id="lighthouse_test",
+                 timeout=timedelta(seconds=1),
+             )
+             assert result is not None
+             for member in result.participants:
+                 assert member.replica_id == "lighthouse_test"
+
+         finally:
+             # Cleanup
+             lighthouse.shutdown()
+
+     def test_heartbeat_round_trip(self) -> None:
+         lighthouse = LighthouseServer(
+             bind="[::]:0",
+             min_replicas=1,
+             heartbeat_timeout_ms=200,
+         )
+         try:
+             client = LighthouseClient(
+                 addr=lighthouse.address(),
+                 connect_timeout=timedelta(seconds=1),
+             )
+
+             client.heartbeat("rep0")
+
+             # (Should still be alive, as sleep time is less than timeout)
+             time.sleep(0.15)
+             q = client.quorum(
+                 replica_id="rep0",
+                 timeout=timedelta(milliseconds=500),
+             )
+             assert any(m.replica_id == "rep0" for m in q.participants)
+
+             # (Wait long enough for timeout to trigger)
+             time.sleep(0.25)
+             # "Probe" with different replica so we don't revive rep0
+             probe = client.quorum(
+                 replica_id="probe",
+                 timeout=timedelta(milliseconds=500),
+             )
+             assert all(m.replica_id != "rep0" for m in probe.participants)
+
+         finally:
+             lighthouse.shutdown()
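
The server/client lifecycle these tests exercise, condensed into a standalone sketch (every call mirrors one made in the tests above):

    from datetime import timedelta

    from torchft._torchft import LighthouseClient, LighthouseServer

    # Bind to an ephemeral IPv6 port; quorums form once min_replicas have joined.
    lighthouse = LighthouseServer(bind="[::]:0", min_replicas=1, heartbeat_timeout_ms=200)
    try:
        client = LighthouseClient(
            addr=lighthouse.address(),
            connect_timeout=timedelta(seconds=1),
        )
        client.heartbeat("rep0")  # keeps the replica alive between quorum calls
        quorum = client.quorum(
            replica_id="rep0",
            timeout=timedelta(milliseconds=500),
        )
        print([m.replica_id for m in quorum.participants])
    finally:
        lighthouse.shutdown()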