torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/futures.py
ADDED
|
@@ -0,0 +1,353 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import asyncio
|
|
8
|
+
import os
|
|
9
|
+
import queue
|
|
10
|
+
import sys
|
|
11
|
+
import threading
|
|
12
|
+
import time
|
|
13
|
+
from contextlib import contextmanager, nullcontext
|
|
14
|
+
from datetime import timedelta
|
|
15
|
+
from typing import Callable, Generator, Optional, TypeVar
|
|
16
|
+
from unittest.mock import Mock
|
|
17
|
+
|
|
18
|
+
import torch
|
|
19
|
+
from torch.futures import Future
|
|
20
|
+
|
|
21
|
+
from torchft.utils import get_stream_context
|
|
22
|
+
|
|
23
|
+
# Type variable for the value carried by a Future in this module's helpers.
T = TypeVar("T")
|
|
24
|
+
|
|
25
|
+
# Name of the environment variable that _TimeoutManager reads to set its
# watchdog interval, in whole seconds (parsed with int(); defaults to 30).
WATCHDOG_TIMEOUT_SEC = "TORCHFT_WATCHDOG_TIMEOUT_SEC"
|
|
26
|
+
|
|
27
|
+
|
|
28
|
+
class _TimerHandle:
    """
    Lock-protected wrapper around an ``asyncio.TimerHandle`` that can be
    cancelled before the underlying timer has been attached.

    ``cancel`` may race with ``set_timer_handle``. If ``cancel`` happens first,
    the timer is cancelled as soon as it is attached.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._timer_handle: Optional[asyncio.TimerHandle] = None
        self._cancelled = False

    def set_timer_handle(self, timer_handle: asyncio.TimerHandle) -> None:
        """Attach the scheduled timer, cancelling it at once if already cancelled."""
        with self._lock:
            if not self._cancelled:
                self._timer_handle = timer_handle
                return
            # cancel() already ran: the timer must never fire.
            timer_handle.cancel()
            self._timer_handle = None

    def cancel(self) -> None:
        """Cancel the attached timer (or the one attached later). Call at most once."""
        with self._lock:
            assert not self._cancelled, "timer can only be cancelled once"
            self._cancelled = True
            pending = self._timer_handle
            if pending is None:
                return
            pending.cancel()
            self._timer_handle = None
|
|
49
|
+
|
|
50
|
+
|
|
51
|
+
class _TimeoutManager:
    """
    This class manages timeouts for code blocks, futures and accelerator
    (e.g. CUDA) events. It uses a background thread running an asyncio event
    loop to schedule the timeouts, and invokes the callback on that thread
    when a timeout is reached.

    A separate watchdog thread periodically posts a trivial probe to the event
    loop. If the probe has not run after ``_watchdog_interval``, the loop is
    reported as stuck.

    Generally there is a single instance of this class that is used for all
    timeouts. Callbacks run on the event-loop thread and should not block,
    otherwise other timeouts may not be processed.
    """

    def __init__(self) -> None:
        self._lock = threading.Lock()
        self._event_loop: Optional[asyncio.AbstractEventLoop] = None
        self._event_loop_thread: Optional[threading.Thread] = None
        self._next_timer_id = 0

        # Ensures `_event_loop_thread` is not stuck.
        self._watchdog_thread: Optional[threading.Thread] = None

        # Give this much time to the `_event_loop_thread` to confirm that it is
        # not stuck. Read in whole seconds from the TORCHFT_WATCHDOG_TIMEOUT_SEC
        # environment variable (default 30).
        self._watchdog_interval = timedelta(
            seconds=int(os.environ.get(WATCHDOG_TIMEOUT_SEC, "30"))
        )

        # This queue is used to delete events outside the event loop, as
        # cudaEventDestroy can block if the CUDA queue is full.
        self._del_queue: queue.SimpleQueue[object] = queue.SimpleQueue()

    def _maybe_start_event_loop(self) -> asyncio.AbstractEventLoop:
        """
        Start the event loop thread (and its watchdog thread) if it has not
        already been started, and return the event loop.
        """
        with self._lock:
            if self._event_loop is None:
                self._event_loop = asyncio.new_event_loop()
                self._event_loop_thread = threading.Thread(
                    target=self._event_loop.run_forever,
                    daemon=True,
                    name="TimeoutManager",
                )
                self._event_loop_thread.start()

                self._watchdog_thread = threading.Thread(
                    target=self._watchdog_loop, daemon=True
                )
                self._watchdog_thread.start()

            # pyre-fixme[7]: optional
            return self._event_loop

    def _watchdog_loop(self) -> None:
        """
        Body of the watchdog thread.

        Every ``_watchdog_interval`` it schedules a probe on the event loop and
        then checks that the probe ran. If it did not, the loop is considered
        stuck and ``sys.exit(1)`` is called. Returns once the event loop has
        been shut down.
        """
        while True:
            is_healthy = False

            def updated_health() -> None:
                nonlocal is_healthy
                is_healthy = True

            with self._lock:
                loop = self._event_loop
                if loop is None:
                    return

                # The method passed to the event loop should finish fast.
                # It just updates a bool, which is also thread safe.
                loop.call_soon_threadsafe(updated_health)

            time.sleep(self._watchdog_interval.total_seconds())

            # `shutdown` may have stopped the loop while we slept, possibly
            # before the probe got to run. A stopped loop is not a stuck loop,
            # so stop watching instead of reporting it. This is deliberately
            # read without `_lock`: `shutdown` holds the lock while joining the
            # loop thread, which would block here if the loop really were stuck.
            if self._event_loop is not loop or not loop.is_running():
                return

            if not is_healthy:
                print("TimeoutManager is stuck. Exiting.")
                # NOTE(review): called from this non-main thread, sys.exit only
                # raises SystemExit in the watchdog thread itself and does not
                # terminate the process. Unit tests patch `sys.exit`, so a
                # switch to a hard exit (e.g. os._exit) must update them too.
                sys.exit(1)
                # Needed because `sys.exit` is mocked in unit tests.
                # If we don't return here, we don't break out of the loop.
                return

    def shutdown(self) -> None:
        """
        Stop the event loop so that pending timeouts no longer fire, join the
        event-loop thread, then join the watchdog thread. Joining the watchdog
        can take up to one ``_watchdog_interval`` because it sleeps between
        probes.
        """
        watchdog_thread = None
        with self._lock:
            if self._event_loop is not None:
                self._event_loop.call_soon_threadsafe(self._event_loop.stop)
                assert self._event_loop_thread is not None
                self._event_loop_thread.join()
                self._event_loop = None
                self._event_loop_thread = None

            # We can't join the watchdog thread here because it grabs `_lock`.
            watchdog_thread = self._watchdog_thread

        if watchdog_thread is not None:
            # If `_maybe_start_event_loop` is called again, then it is possible
            # the `join` below will never finish.
            # This class assumes `_maybe_start_event_loop` will not be called after `shutdown`.
            # If this functionality is required in the future, we could change the class to
            # support this. Or create multiple instances of this class.
            watchdog_thread.join()

    def register(self, fut: Future[T], timeout: timedelta) -> Future[T]:
        """
        Return a future that completes with ``fut``'s result or exception, or
        with a ``TimeoutError`` if ``fut`` does not complete within ``timeout``.

        ``fut`` itself is not cancelled. ``Mock`` futures are returned unchanged.
        """
        # bypass timeout for mock futures
        if isinstance(fut, Mock):
            return fut

        self._clear_del_queue()

        loop = self._maybe_start_event_loop()

        timed_fut: Future[T] = Future()
        handle: _TimerHandle = _TimerHandle()

        def on_timeout() -> None:
            # Runs on the event-loop thread and races with `callback` below:
            # `timed_fut` may already be completed. Completing it again raises,
            # so skip it rather than surfacing an error inside the event loop.
            if timed_fut.done():
                return
            try:
                timed_fut.set_exception(
                    # pyre-fixme[6]: e is not T
                    TimeoutError(f"future did not complete within {timeout}")
                )
            except Exception:
                pass

        loop.call_soon_threadsafe(
            self._register_callback,
            loop,
            on_timeout,
            timeout,
            handle,
        )

        stream: Optional[torch.Stream] = (
            torch.accelerator.current_stream()
            if torch.accelerator.is_available()
            else None
        )

        def callback(fut: Future[T]) -> None:
            # Run on the stream that was current when `register` was called.
            with get_stream_context(stream):
                handle.cancel()
                try:
                    timed_fut.set_result(fut.wait())
                except Exception as e:
                    try:
                        # this can throw if the future is already done
                        # pyre-fixme[6]: e is not T
                        timed_fut.set_exception(e)
                    except Exception:
                        pass

        fut.add_done_callback(callback)
        return timed_fut

    def stream_timeout(self, callback: Callable[[], None], timeout: timedelta) -> None:
        """
        Call ``callback`` on the event-loop thread after ``timeout`` if an
        event recorded on the current stream now has not completed by then.
        """
        self._clear_del_queue()

        loop = self._maybe_start_event_loop()

        event: torch.Event = torch.Event()
        event.record()

        def handler() -> None:
            if not event.query():
                callback()

            # cudaEventDestroy can block so we never want to delete in the event
            # loop. Put it on the del queue so it is deleted by the next caller
            # of `_clear_del_queue` instead.
            self._del_queue.put(event)

        loop.call_soon_threadsafe(
            self._register_callback, loop, handler, timeout, _TimerHandle()
        )

    @classmethod
    def _register_callback(
        cls,
        loop: asyncio.AbstractEventLoop,
        callback: Callable[[], None],
        timeout: timedelta,
        handle: _TimerHandle,
    ) -> None:
        """Schedule ``callback`` after ``timeout``; must run on ``loop``'s thread."""
        timer_handle = loop.call_later(
            timeout.total_seconds(),
            callback,
        )
        handle.set_timer_handle(timer_handle)

    @contextmanager
    def context_timeout(
        self, callback: Callable[[], None], timeout: timedelta
    ) -> Generator[None, None, None]:
        """
        Call ``callback`` on the event-loop thread if the ``with`` body has not
        exited within ``timeout``. The timer is cancelled when the body exits,
        whether it returns normally or raises.
        """
        self._clear_del_queue()

        loop = self._maybe_start_event_loop()
        handle = _TimerHandle()

        loop.call_soon_threadsafe(
            self._register_callback, loop, callback, timeout, handle
        )

        try:
            yield
        finally:
            # Cancel even when the body raises; otherwise the callback would
            # fire after the context has already been exited.
            handle.cancel()

    def _clear_del_queue(self) -> int:
        """
        Clear the queue of objects to be deleted.

        Returns the number of items deleted.
        """
        count = 0
        while True:
            try:
                # get and immediately discard item
                item = self._del_queue.get_nowait()
                refcount = sys.getrefcount(item)
                assert (
                    # 1 from item, 1 from getrefcount
                    refcount == 2
                ), f"items in del_queue reference should not have other references, found {refcount=}"
                del item

                count += 1
            except queue.Empty:
                break

        return count
|
|
274
|
+
|
|
275
|
+
|
|
276
|
+
# Process-wide timeout manager shared by future_timeout, future_wait,
# stream_timeout and context_timeout. Constructing it starts no threads; its
# event loop and watchdog threads are started lazily on first use.
_TIMEOUT_MANAGER = _TimeoutManager()
|
|
277
|
+
|
|
278
|
+
|
|
279
|
+
def future_timeout(fut: Future[T], timeout: timedelta) -> Future[T]:
    """
    Wrap ``fut`` so the returned Future resolves with ``fut``'s outcome, or
    with a ``TimeoutError`` if ``fut`` has not completed within ``timeout``.

    Args:
        fut: The Future to wait for
        timeout: How long to wait for the Future to complete

    Returns:
        A Future carrying either ``fut``'s result/exception or a TimeoutError
    """
    manager = _TIMEOUT_MANAGER
    timed = manager.register(fut, timeout)
    return timed
|
|
292
|
+
|
|
293
|
+
|
|
294
|
+
def future_wait(fut: Future[T], timeout: timedelta) -> T:
    """
    Block until ``fut`` completes or ``timeout`` elapses, whichever is first.

    Args:
        fut: The Future to wait for
        timeout: How long to wait for the Future to complete

    Returns:
        The Future's result, if it completed within the timeout.

    Raises:
        TimeoutError if the Future did not complete within the timeout.
        Any other exception that occurred in the Future.
    """
    completed = threading.Event()

    def _signal_and_unwrap(done: Future[T]) -> T:
        # Wake the waiter first, then propagate the value (or exception).
        completed.set()
        return done.wait()

    chained = fut.then(_signal_and_unwrap)

    finished_in_time = completed.wait(timeout=timeout.total_seconds())
    if finished_in_time:
        return chained.wait()
    raise TimeoutError(f"future did not complete within {timeout}")
|
|
322
|
+
|
|
323
|
+
|
|
324
|
+
def stream_timeout(callback: Callable[[], None], timeout: timedelta) -> None:
    """
    Arrange for ``callback`` to run if the work currently queued on the
    current stream has not finished once ``timeout`` has elapsed.

    An accelerator event (e.g. a CUDA event) is recorded on the current stream
    and queried when the timeout expires. If it is still incomplete, the
    callback is invoked.

    Args:
        callback: Invoked if the stream doesn't complete in time.
        timeout: How long to wait for the stream to complete.
    """
    manager = _TIMEOUT_MANAGER
    manager.stream_timeout(callback, timeout)
|
|
337
|
+
|
|
338
|
+
|
|
339
|
+
@contextmanager
def context_timeout(
    callback: Callable[[], None], timeout: timedelta
) -> Generator[None, None, None]:
    """
    Context manager that invokes ``callback`` if its body has not exited
    within ``timeout``.

    Args:
        callback: Invoked if we time out.
        timeout: How long to wait for the contextmanager to exit.
    """
    guard = _TIMEOUT_MANAGER.context_timeout(callback, timeout)
    with guard:
        yield
|
torchft/futures_test.py
ADDED
|
@@ -0,0 +1,140 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import threading
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from unittest import skipUnless, TestCase
|
|
10
|
+
from unittest.mock import Mock, patch
|
|
11
|
+
|
|
12
|
+
import torch
|
|
13
|
+
from torch.futures import Future
|
|
14
|
+
|
|
15
|
+
from torchft.futures import (
|
|
16
|
+
_TIMEOUT_MANAGER,
|
|
17
|
+
context_timeout,
|
|
18
|
+
future_timeout,
|
|
19
|
+
future_wait,
|
|
20
|
+
stream_timeout,
|
|
21
|
+
)
|
|
22
|
+
|
|
23
|
+
|
|
24
|
+
class FuturesTest(TestCase):
    def setUp(self) -> None:
        # Use a short watchdog interval so the stuck-callback test finishes
        # quickly; the original value is restored in tearDown.
        self._original_watchdog_interval = _TIMEOUT_MANAGER._watchdog_interval
        _TIMEOUT_MANAGER._watchdog_interval = timedelta(seconds=1)

    def tearDown(self) -> None:
        _TIMEOUT_MANAGER._watchdog_interval = self._original_watchdog_interval

    def test_future_wait(self) -> None:
        fut = Future()
        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
            future_wait(fut, timeout=timedelta(seconds=0.01))

        fut = Future()
        fut.set_result(1)
        self.assertEqual(future_wait(fut, timeout=timedelta(seconds=1.0)), 1)

        fut = Future()
        fut.set_exception(RuntimeError("test"))
        with self.assertRaisesRegex(RuntimeError, "test"):
            future_wait(fut, timeout=timedelta(seconds=1.0))

    def test_future_timeout(self) -> None:
        fut = Future()
        timed_fut = future_timeout(fut, timeout=timedelta(seconds=0.01))
        with self.assertRaisesRegex(TimeoutError, "future did not complete within"):
            timed_fut.wait()

    def test_future_timeout_result(self) -> None:
        fut = Future()
        timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
        fut.set_result(1)
        self.assertEqual(timed_fut.wait(), 1)

    def test_future_timeout_exception(self) -> None:
        fut = Future()
        timed_fut = future_timeout(fut, timeout=timedelta(seconds=10))
        fut.set_exception(RuntimeError("test"))
        with self.assertRaisesRegex(RuntimeError, "test"):
            timed_fut.wait()

    def test_context_timeout(self) -> None:
        # The barrier timeout turns a timeout that never fires into a test
        # failure (BrokenBarrierError) instead of a hang.
        barrier: threading.Barrier = threading.Barrier(2, timeout=10)

        def callback() -> None:
            barrier.wait()

        with context_timeout(callback, timedelta(seconds=0.01)):
            # block until timeout fires
            barrier.wait()

        def fail() -> None:
            self.fail("timeout should be cancelled")

        with context_timeout(fail, timedelta(seconds=10)):
            pass

    # pyre-fixme[56]: Pyre was not able to infer the type of decorator
    @skipUnless(torch.cuda.is_available(), "CUDA is required for this test")
    def test_stream_timeout(self) -> None:
        torch.cuda.synchronize()

        def callback() -> None:
            self.fail()

        stream_timeout(callback, timeout=timedelta(seconds=0.01))

        # make sure event completes
        torch.cuda.synchronize()

        # make sure that event is deleted on the deletion queue
        item = _TIMEOUT_MANAGER._del_queue.get(timeout=10.0)
        _TIMEOUT_MANAGER._del_queue.put(item)
        del item

        self.assertEqual(_TIMEOUT_MANAGER._clear_del_queue(), 1)

    # Test that when a timeout handle gets stuck, `sys.exit(1)` is called
    @patch("sys.exit")
    def test_exit_on_stuck_callback(self, mock_exit: Mock) -> None:
        exit_event: threading.Event = threading.Event()

        def custom_exit(_: object) -> None:
            # 3. When event loop is stuck, exit(1) is called
            exit_event.set()

        mock_exit.side_effect = custom_exit

        callback_event: threading.Event = threading.Event()

        def callback() -> None:
            # 2. Make sure callback blocks event loop
            callback_event.wait()

        context_event: threading.Event = threading.Event()

        def thread_fn() -> None:
            with context_timeout(callback, timedelta(seconds=0.01)):
                # 1. Make sure context doesn't finish in time
                context_event.wait()

        thread = threading.Thread(target=thread_fn)
        thread.start()

        try:
            # 4. exit event will wake this up. The timeout is generous because
            # the watchdog may be mid-sleep with a longer, earlier interval.
            self.assertTrue(
                exit_event.wait(timeout=60), "watchdog did not call sys.exit"
            )
            mock_exit.assert_called_once_with(1)
        finally:
            # Always unblock: the shared event loop would otherwise stay stuck
            # for every later test in this process.
            # 5. event loop is still stuck, so let's unblock it
            callback_event.set()

            # 6. unblock the context and make sure it exits
            context_event.set()
            thread.join()
|
torchft/http.py
ADDED
|
@@ -0,0 +1,13 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import socket
|
|
8
|
+
from http.server import ThreadingHTTPServer
|
|
9
|
+
|
|
10
|
+
|
|
11
|
+
class _IPv6HTTPServer(ThreadingHTTPServer):
    """
    ``ThreadingHTTPServer`` that binds an IPv6 (``AF_INET6``) socket and
    raises the pending-connection queue size (listen backlog) to 1024.
    """

    request_queue_size: int = 1024
    address_family: socket.AddressFamily = socket.AF_INET6
|
|
@@ -0,0 +1,163 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import time
|
|
8
|
+
from datetime import timedelta
|
|
9
|
+
from unittest import TestCase
|
|
10
|
+
|
|
11
|
+
import torch.distributed as dist
|
|
12
|
+
|
|
13
|
+
from torchft import Manager, ProcessGroupGloo
|
|
14
|
+
from torchft._torchft import LighthouseClient, LighthouseServer, Quorum, QuorumMember
|
|
15
|
+
|
|
16
|
+
|
|
17
|
+
class TestLighthouse(TestCase):
    def test_join_timeout_behavior(self) -> None:
        """Test that a single replica reaches quorum quickly with a 100ms join_timeout_ms."""
        # Create a lighthouse with a 100ms join timeout and measure how long a
        # single manager takes to reach quorum.
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=1,
            join_timeout_ms=100,
        )

        manager = None
        try:
            store = dist.TCPStore(
                host_name="localhost",
                port=0,
                is_master=True,
                wait_for_workers=False,
            )
            pg = ProcessGroupGloo()
            manager = Manager(
                pg=pg,
                min_replica_size=1,
                load_state_dict=lambda x: None,
                state_dict=lambda: None,
                replica_id="lighthouse_test",
                store_addr="localhost",
                store_port=store.port,
                rank=0,
                world_size=1,
                use_async_quorum=False,
                lighthouse_addr=lighthouse.address(),
            )

            start_time = time.monotonic()
            manager.start_quorum()
            time_taken = time.monotonic() - start_time
            assert time_taken < 0.4, f"Time taken to join: {time_taken} > 0.4s"

        finally:
            # Cleanup
            lighthouse.shutdown()
            if manager is not None:
                manager.shutdown()

        # A lighthouse with a longer (400ms) join timeout must also construct
        # cleanly. Shut it down so it does not leak a running server.
        # NOTE(review): this server is never exercised; a timing comparison
        # against the 100ms case may have been intended.
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=1,
            join_timeout_ms=400,
        )
        lighthouse.shutdown()

    def test_heartbeat_timeout_ms_sanity(self) -> None:
        """Test that a lighthouse can be created with heartbeat_timeout_ms and shut down."""
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=1,
            heartbeat_timeout_ms=100,
        )
        lighthouse.shutdown()

    def test_lighthouse_client_behavior(self) -> None:
        """Test that LighthouseClient.quorum forms a quorum and round-trips member data."""
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=1,
            join_timeout_ms=100,
        )

        # Connect a client directly to the lighthouse (no Manager involved).
        try:
            client = LighthouseClient(
                addr=lighthouse.address(),
                connect_timeout=timedelta(seconds=1),
            )
            store = dist.TCPStore(
                host_name="localhost",
                port=0,
                is_master=True,
                wait_for_workers=False,
            )
            result = client.quorum(
                replica_id="lighthouse_test",
                address="localhost",
                store_address=f"localhost:{store.port}",
                step=1,
                world_size=1,
                shrink_only=False,
                timeout=timedelta(seconds=1),
                data={"my_data": 1234},
            )
            assert result is not None
            assert isinstance(result, Quorum)
            assert len(result.participants) == 1
            for member in result.participants:
                assert isinstance(member, QuorumMember)
                assert member.replica_id == "lighthouse_test"
                assert member.data is not None
                assert "my_data" in member.data
                assert member.data["my_data"] == 1234

            # Test the optional args
            result = client.quorum(
                replica_id="lighthouse_test",
                timeout=timedelta(seconds=1),
            )
            assert result is not None
            for member in result.participants:
                assert member.replica_id == "lighthouse_test"

        finally:
            # Cleanup
            lighthouse.shutdown()

    def test_heartbeat_round_trip(self) -> None:
        """Test that a replica stays in quorum while heartbeating and drops out after the timeout."""
        lighthouse = LighthouseServer(
            bind="[::]:0",
            min_replicas=1,
            heartbeat_timeout_ms=200,
        )
        try:
            client = LighthouseClient(
                addr=lighthouse.address(),
                connect_timeout=timedelta(seconds=1),
            )

            client.heartbeat("rep0")

            # (Should still be alive, as sleep time is less than timeout)
            time.sleep(0.15)
            q = client.quorum(
                replica_id="rep0",
                timeout=timedelta(milliseconds=500),
            )
            assert any(m.replica_id == "rep0" for m in q.participants)

            # (Wait long enough for timeout to trigger)
            time.sleep(0.25)
            # "Probe" with different replica so we don't revive rep0
            probe = client.quorum(
                replica_id="probe",
                timeout=timedelta(milliseconds=500),
            )
            assert all(m.replica_id != "rep0" for m in probe.participants)

        finally:
            lighthouse.shutdown()
|