torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/manager_test.py
ADDED
|
@@ -0,0 +1,911 @@
|
|
|
1
|
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
2
|
+
# All rights reserved.
|
|
3
|
+
#
|
|
4
|
+
# This source code is licensed under the BSD-style license found in the
|
|
5
|
+
# LICENSE file in the root directory of this source tree.
|
|
6
|
+
|
|
7
|
+
import concurrent
|
|
8
|
+
import threading
|
|
9
|
+
import time
|
|
10
|
+
from datetime import timedelta
|
|
11
|
+
from typing import Optional
|
|
12
|
+
from unittest import TestCase
|
|
13
|
+
from unittest.mock import create_autospec, MagicMock, patch
|
|
14
|
+
|
|
15
|
+
import torch
|
|
16
|
+
from torch.distributed import ReduceOp, TCPStore
|
|
17
|
+
|
|
18
|
+
from torchft._torchft import QuorumResult
|
|
19
|
+
from torchft.checkpointing._rwlock import RWLock
|
|
20
|
+
from torchft.checkpointing.transport import CheckpointTransport
|
|
21
|
+
from torchft.manager import Manager, MANAGER_ADDR_KEY, REPLICA_ID_KEY, WorldSizeMode
|
|
22
|
+
from torchft.process_group import ProcessGroup
|
|
23
|
+
from torchft.work import _DummyWork
|
|
24
|
+
|
|
25
|
+
|
|
26
|
+
def mock_should_commit(
    rank: int, step: int, should_commit: bool, timeout: timedelta
) -> bool:
    """Stand-in for the mocked client's ``should_commit`` used by the tests.

    It simply echoes back the ``should_commit`` argument it receives.
    ``rank``, ``step`` and ``timeout`` are accepted only so the callable has
    the same parameters as the call it replaces; they are ignored.
    """
    vote = should_commit
    return vote
|
|
30
|
+
|
|
31
|
+
|
|
32
|
+
class TestManager(TestCase):
|
|
33
|
+
store: TCPStore # pyre-fixme[13]: never initialized
|
|
34
|
+
load_state_dict: MagicMock # pyre-fixme[13]: never initialized
|
|
35
|
+
manager: Optional[Manager] # pyre-fixme[13]: never initialized
|
|
36
|
+
|
|
37
|
+
def tearDown(self) -> None:
    """Shut down the most recently created manager, if any, without waiting."""
    # ``manager`` is only annotated on the class (never assigned there), so the
    # attribute exists only once a test has called ``_create_manager``.
    active = getattr(self, "manager", None)
    if active is None:
        return
    active.shutdown(wait=False)
|
|
41
|
+
|
|
42
|
+
def _create_manager(
    self,
    use_async_quorum: bool = True,
    min_replica_size: int = 2,
    world_size_mode: WorldSizeMode = WorldSizeMode.DYNAMIC,
    timeout: timedelta = timedelta(seconds=10),
    init_sync: bool = True,
    max_retries: Optional[int] = None,
) -> Manager:
    """Build a ``Manager`` backed by a mocked process group and a fresh TCPStore.

    Side effects: sets ``self.store`` (a new master TCPStore on a free port,
    pre-populated with MANAGER_ADDR_KEY / REPLICA_ID_KEY), ``self.load_state_dict``
    (a MagicMock passed as the load callback) and ``self.manager``.

    Args:
        use_async_quorum: forwarded to ``Manager``.
        min_replica_size: forwarded to ``Manager``.
        world_size_mode: forwarded to ``Manager``.
        timeout: forwarded to ``Manager``.
        init_sync: forwarded to ``Manager``.
        max_retries: forwarded to ``Manager``.

    Returns:
        The newly constructed manager (also stored on ``self.manager``).
    """
    # Some tests build several managers (e.g. one per loop iteration), but
    # tearDown only shuts down ``self.manager``. Release the previous one here
    # before it is overwritten so it does not leak.
    previous = getattr(self, "manager", None)
    if previous is not None:
        previous.shutdown(wait=False)
        self.manager = None

    pg = create_autospec(ProcessGroup)
    # a healthy process group reports no error
    pg.errored.return_value = None

    self.store = TCPStore(
        host_name="localhost", port=0, is_master=True, wait_for_workers=False
    )
    self.store.set(MANAGER_ADDR_KEY, "dummy")
    self.store.set(REPLICA_ID_KEY, "dummy_id")
    # Replace the whole environment while the Manager is constructed so it
    # sees only these rendezvous variables.
    with patch(
        "os.environ",
        {
            "MASTER_ADDR": "localhost",
            "MASTER_PORT": self.store.port,
            "RANK": "1",
            "WORLD_SIZE": "2",
        },
    ):
        self.load_state_dict = MagicMock()
        manager = Manager(
            pg=pg,
            min_replica_size=min_replica_size,
            load_state_dict=self.load_state_dict,
            state_dict=lambda: {},
            use_async_quorum=use_async_quorum,
            world_size_mode=world_size_mode,
            timeout=timeout,
            init_sync=init_sync,
            max_retries=max_retries,
        )
    self.manager = manager
    return manager
|
|
82
|
+
|
|
83
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager(self, client_mock: MagicMock) -> None:
    """Constructing a manager instantiates the (patched) ManagerClient once."""
    self._create_manager()
    self.assertEqual(client_mock.call_count, 1)
|
|
87
|
+
|
|
88
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict(self, client_mock: MagicMock) -> None:
    """state_dict() starts at zero and load_state_dict() restores the counters."""
    manager = self._create_manager()

    initial = {"step": 0, "batches_committed": 0}
    self.assertEqual(manager.state_dict(), initial)

    manager.load_state_dict({"step": 1234, "batches_committed": 2345})

    self.assertEqual(manager.current_step(), 1234)
    self.assertEqual(manager.batches_committed(), 2345)
|
|
109
|
+
|
|
110
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_user_state_dict(self, client_mock: MagicMock) -> None:
    """_manager_state_dict() nests registered user state under "user".

    A "default" entry is present from construction (presumably the
    ``state_dict`` callback passed to Manager). Registering "state" adds a
    sibling entry, and the torchft counters stay under "torchft".
    """
    manager = self._create_manager()
    torchft_counters = {"step": 0, "batches_committed": 0}

    self.assertEqual(
        manager._manager_state_dict(),
        {"user": {"default": {}}, "torchft": torchft_counters},
    )

    manager.register_state_dict_fn(
        "state", self.load_state_dict, lambda: {"new_state": 1}
    )

    expected_user = {"default": {}, "state": {"new_state": 1}}
    self.assertEqual(
        manager._manager_state_dict(),
        {"user": expected_user, "torchft": torchft_counters},
    )
|
|
146
|
+
|
|
147
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_happy(self, client_mock: MagicMock) -> None:
    """A non-healing quorum runs one allreduce and commits step 1.

    After a second start_quorum(), batches_committed() is expected to be 2
    (the quorum reports two replicas).
    """
    manager = self._create_manager()
    # the client's should_commit echoes back the should_commit argument it receives
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    # initial state: no quorum yet
    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)
    self.assertEqual(manager.batches_committed(), 0)

    manager.start_quorum()
    manager.allreduce(torch.tensor([1.0])).wait()
    self.assertTrue(manager.should_commit())

    self.assertEqual(manager._quorum_id, 123)
    self.assertEqual(manager.current_step(), 1)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 1)

    manager.start_quorum()
    self.assertEqual(manager.batches_committed(), 2)
|
|
180
|
+
|
|
181
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_sync(self, client_mock: MagicMock) -> None:
    """Synchronous quorum with heal=True recovers state, then commits.

    After start_quorum() + allreduce the replica is no longer healing and
    participates, load_state_dict is called once, and should_commit()
    advances the step to max_step + 1 (21).
    """
    manager = self._create_manager(use_async_quorum=False)
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.recover_src_replica_rank = 0
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 20
    quorum.max_replica_rank = None
    quorum.max_world_size = 2
    quorum.heal = True

    client_mock()._quorum.return_value = quorum

    # forcibly increment checkpoint server to compute correct address
    # (publish a checkpoint at max_step on this manager's own transport and
    # hand its metadata to the mocked client; presumably the healing path
    # then fetches the state from here)
    manager._checkpoint_transport.send_checkpoint(
        dst_ranks=[],
        step=quorum.max_step,
        state_dict=manager._manager_state_dict(),
        timeout=timedelta(seconds=10),
    )
    client_mock()._checkpoint_metadata.return_value = (
        manager._checkpoint_transport.metadata()
    )

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)

    # before any quorum there are no participants
    self.assertEqual(manager.num_participants(), 0)
    self.assertEqual(manager.participating_rank(), None)

    manager.start_quorum()
    manager.allreduce(torch.tensor([1.0])).wait()
    self.assertFalse(manager._healing)
    self.assertTrue(manager.is_participating())
    self.assertEqual(manager.num_participants(), 2)
    self.assertTrue(manager.should_commit())

    self.assertEqual(manager._quorum_id, 123)
    self.assertEqual(manager.current_step(), 21)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 1)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)

    self.assertEqual(self.load_state_dict.call_count, 1)
|
|
232
|
+
|
|
233
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_async_not_enough_participants(
    self, client_mock: MagicMock
) -> None:
    """Async healing with max_world_size (1) < min_replica_size (2).

    The healing replica contributes zeros to the allreduce, should_commit()
    is False, and the step stays at max_step (20) even after a further
    non-healing quorum.
    """
    manager = self._create_manager(use_async_quorum=True, min_replica_size=2)
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.recover_src_replica_rank = 0
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 20
    quorum.max_replica_rank = None
    quorum.max_world_size = 1
    quorum.heal = True

    client_mock()._quorum.return_value = quorum

    # forcibly increment checkpoint server to compute correct address
    manager._checkpoint_transport.send_checkpoint(
        dst_ranks=[],
        step=quorum.max_step,
        state_dict=manager._manager_state_dict(),
        timeout=timedelta(seconds=10),
    )
    client_mock()._checkpoint_metadata.return_value = (
        manager._checkpoint_transport.metadata()
    )

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)

    manager.start_quorum()
    # async quorum: wait for the background quorum to finish before inspecting state
    assert manager._quorum_future is not None
    manager._quorum_future.result()
    self.assertTrue(manager._healing)
    self.assertFalse(manager.is_participating())
    self.assertEqual(manager.num_participants(), 1)

    grad = torch.tensor([1.0])
    manager.allreduce(grad).wait()
    # the healing replica's contribution is zeroed
    torch.testing.assert_close(grad, torch.zeros_like(grad))
    # don't commit since num_max < min_replica_size
    self.assertFalse(manager.should_commit())
    self.assertEqual(manager.current_step(), 20)

    self.assertEqual(manager._quorum_id, 123)
    self.assertEqual(manager.current_step(), 20)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 1)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)

    self.assertEqual(self.load_state_dict.call_count, 1)

    # failed to commit so no step
    quorum.heal = False
    manager.start_quorum()
    self.assertEqual(manager.current_step(), 20)
    self.assertEqual(manager.batches_committed(), 0)
|
|
296
|
+
|
|
297
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_heal_async_zero_grad(self, client_mock: MagicMock) -> None:
    """Async healing with min_replica_size=1 zeroes the grad but still commits.

    The healing replica contributes zeros to the allreduce. Because
    max_world_size (1) >= min_replica_size (1), should_commit() is True and
    the step advances to max_step + 1 (21).
    """
    manager = self._create_manager(use_async_quorum=True, min_replica_size=1)
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.recover_src_replica_rank = 0
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 20
    quorum.max_replica_rank = None
    quorum.max_world_size = 1
    quorum.heal = True

    client_mock()._quorum.return_value = quorum

    # forcibly increment checkpoint server to compute correct address
    manager._checkpoint_transport.send_checkpoint(
        dst_ranks=[],
        step=quorum.max_step,
        state_dict=manager._manager_state_dict(),
        timeout=timedelta(seconds=10),
    )
    client_mock()._checkpoint_metadata.return_value = (
        manager._checkpoint_transport.metadata()
    )

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)

    manager.start_quorum()
    # async quorum: wait for the background quorum to finish before inspecting state
    assert manager._quorum_future is not None
    manager._quorum_future.result()
    self.assertTrue(manager._healing)

    grad = torch.tensor([1.0])
    manager.allreduce(grad).wait()
    # the healing replica's contribution is zeroed
    torch.testing.assert_close(grad, torch.zeros_like(grad))
    # commit since num_max >= min_replica_size
    self.assertTrue(manager.should_commit())
    self.assertEqual(manager.num_participants(), 1)
    # assertEqual, not assertTrue: assertTrue(x, 21) treats 21 as the failure
    # message and passes for any non-zero step.
    self.assertEqual(manager.current_step(), 21)

    self.assertEqual(manager._quorum_id, 123)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 1)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)

    self.assertEqual(self.load_state_dict.call_count, 1)

    # healed
    quorum.heal = False
    manager.start_quorum()
    self.assertEqual(manager.current_step(), 21)
    self.assertEqual(manager.batches_committed(), 1)
|
|
356
|
+
|
|
357
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_allreduce_error(self, client_mock: MagicMock) -> None:
    """Allreduce failures block the current step's commit and clear on the next quorum.

    Two failure points are covered: the process group raising while the
    collective is queued, and the work's future completing with an
    exception. In both cases should_commit() is False. A later quorum with
    a healthy process group commits again.
    """
    manager = self._create_manager()
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)

    manager.start_quorum()
    manager.allreduce(torch.tensor([1.0])).wait()
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 1)

    # inject failure when work queued
    # pyre-ignore[16]: _pg is mocked
    manager._pg.allreduce.side_effect = RuntimeError("injected failure")
    manager.allreduce(torch.tensor([1.0])).wait()
    self.assertTrue(manager._errored)
    # this should be skipped due to error
    manager.allreduce(torch.tensor([1.0])).wait()
    # 2 = the successful call + the call that raised; the third never reached the pg
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.call_count, 2)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 1)

    self.assertFalse(manager.should_commit())
    self.assertTrue(manager._errored)

    # cleanup
    manager._pg.allreduce.side_effect = None

    # inject failure when work is waited on
    quorum.max_step = 2

    manager.start_quorum()

    # the new quorum has cleared the previous step's error
    self.assertFalse(manager._errored)

    bad_fut = torch.futures.Future()
    bad_fut.set_exception(RuntimeError("injected failure"))
    # pyre-ignore[16]: _pg is mocked
    manager._pg.allreduce.return_value.get_future.return_value = bad_fut
    manager.allreduce(torch.tensor([1.0])).wait()
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.allreduce.return_value.get_future.call_count, 2)
    self.assertTrue(manager._errored)
    self.assertFalse(manager.should_commit())
    self.assertTrue(manager._errored)

    # cleanup
    manager._pg.allreduce.reset_mock(return_value=True)

    # recover on next step
    quorum.max_step = 3

    manager.start_quorum()
    manager.allreduce(torch.tensor([1.0])).wait()
    self.assertTrue(manager.should_commit())
|
|
425
|
+
|
|
426
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_pg_errored(self, client_mock: MagicMock) -> None:
    """An error reported by the process group's errored() blocks the commit.

    The reported exception is surfaced as manager._errored.original_exception.
    """
    manager = self._create_manager()
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)

    manager.start_quorum()

    injected_failure = RuntimeError("injected failure")

    # the process group starts reporting an error after the quorum
    # pyre-ignore[16]: _pg is mocked
    manager._pg.errored.return_value = injected_failure

    self.assertFalse(manager.should_commit())
    assert manager._errored is not None
    self.assertEqual(manager._errored.original_exception, injected_failure)
    # errored() was consulted exactly once (presumably by should_commit)
    # pyre-ignore[16]: _pg is mocked
    self.assertEqual(manager._pg.errored.call_count, 1)
|
|
459
|
+
|
|
460
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_fixed_world_size(self, client_mock: MagicMock) -> None:
    """FIXED_WITH_SPARES counts only min_replica_size replicas as participants.

    With three replicas in the quorum and min_replica_size=2, num_participants()
    is 2: replica rank 1 participates and replica rank 2 is a spare. Both
    still commit step 1.
    """
    # test active and spares
    for rank in [1, 2]:
        manager = self._create_manager(
            min_replica_size=2,
            world_size_mode=WorldSizeMode.FIXED_WITH_SPARES,
        )
        client_mock().should_commit = mock_should_commit

        quorum = QuorumResult()
        quorum.quorum_id = 123
        quorum.replica_rank = rank
        quorum.replica_world_size = 3
        quorum.recover_src_manager_address = "manager address"
        quorum.store_address = f"localhost:{self.store.port}"
        quorum.max_step = 1
        quorum.max_replica_rank = rank
        quorum.max_world_size = 3
        quorum.heal = False

        client_mock()._quorum.return_value = quorum

        self.assertEqual(manager._quorum_id, -1)
        self.assertEqual(manager.current_step(), 0)
        self.assertEqual(manager.batches_committed(), 0)

        manager.start_quorum()
        manager.allreduce(torch.tensor([1.0])).wait()

        # rank 2 is the spare
        self.assertEqual(manager.is_participating(), rank != 2)
        self.assertEqual(manager.num_participants(), 2)

        self.assertTrue(manager.should_commit())
        self.assertEqual(manager.batches_committed(), 2)
        self.assertEqual(manager.current_step(), 1)
|
|
496
|
+
|
|
497
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_no_healing(self, client_mock: MagicMock) -> None:
    """start_quorum(allow_heal=False) with a heal=True quorum result.

    The replica does not participate in the step, but should_commit() is
    still True and the counters advance (batches_committed 2, step 1).
    """
    manager = self._create_manager(
        min_replica_size=2,
    )
    client_mock().should_commit = mock_should_commit

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 0
    quorum.replica_world_size = 3
    quorum.recover_src_manager_address = "manager address"
    quorum.recover_src_replica_rank = 1
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = None
    quorum.max_world_size = 2
    quorum.heal = True
    client_mock()._quorum.return_value = quorum

    self.assertEqual(manager._quorum_id, -1)
    self.assertEqual(manager.current_step(), 0)
    self.assertEqual(manager.batches_committed(), 0)

    manager.start_quorum(allow_heal=False)
    manager.allreduce(torch.tensor([1.0])).wait()

    self.assertFalse(manager.is_participating())
    self.assertEqual(manager.num_participants(), 2)

    self.assertTrue(manager.should_commit())
    self.assertEqual(manager.batches_committed(), 2)
    self.assertEqual(manager.current_step(), 1)
|
|
530
|
+
|
|
531
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_report_error(self, client_mock: MagicMock) -> None:
    """report_error() records the exception so errored() exposes it."""
    manager = self._create_manager()
    self.assertIsNone(manager.errored())

    reported = RuntimeError("some error")
    manager.report_error(reported)

    recorded = manager.errored()
    assert recorded is not None
    self.assertIs(recorded.original_exception, reported)
|
|
541
|
+
|
|
542
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future(self, client_mock: MagicMock) -> None:
    """A failed wrapped future marks the manager errored and yields the default.

    Wrapping alone does not flag an error. Once the inner future fails, the
    exception is recorded and the wrapped future resolves to the supplied
    default value (2).
    """
    manager = self._create_manager()
    self.assertIsNone(manager.errored())

    inner = torch.futures.Future()
    wrapped = manager.wrap_future(inner, 2)
    self.assertIsNone(manager.errored())

    failure = RuntimeError("injected failure")
    inner.set_exception(failure)

    recorded = manager.errored()
    assert recorded is not None
    self.assertIs(recorded.original_exception, failure)
    self.assertEqual(wrapped.value(), 2)
|
|
558
|
+
|
|
559
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_wrap_future_timeout(self, client_mock: MagicMock) -> None:
    """A wrapped future that never completes is recorded as a TimeoutError.

    The manager is built with a 10ms timeout. The inner future is never
    completed, so after waiting on the wrapped future errored() holds a
    TimeoutError mentioning that timeout.
    """
    manager = self._create_manager(timeout=timedelta(seconds=0.01))

    # errored() returns None when healthy; assertIsNone (as in the sibling
    # tests) is stricter than assertFalse, which accepts any falsy value.
    self.assertIsNone(manager.errored())

    fut = torch.futures.Future()
    wrapped_fut = manager.wrap_future(fut, 2)
    # fut is never completed; presumably this returns once the manager's timeout fires
    wrapped_fut.wait()
    error = manager.errored()
    assert error is not None
    with self.assertRaisesRegex(
        TimeoutError, "future did not complete within.*0.01"
    ):
        raise error.original_exception
|
|
574
|
+
|
|
575
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_numerics(self, client_mock: MagicMock) -> None:
    """Check allreduce numerics against a participant count of 5.

    For floating-point tensors the default op and ReduceOp.AVG yield the
    input divided by 5. SUM/MAX/MIN/PRODUCT leave the tensor unchanged for
    every dtype, because the mocked process group returns _DummyWork and no
    real reduction happens. While healing, the tensor is zeroed.
    """
    # ``import concurrent`` alone does not guarantee that the
    # ``concurrent.futures`` submodule is loaded and bound as an attribute.
    # Import it explicitly rather than relying on some other module having
    # done so.
    import concurrent.futures

    manager = self._create_manager()

    manager._quorum_future = quorum_future = MagicMock(
        spec=concurrent.futures.Future
    )
    manager._participating_replica_rank = 1
    manager._participating_replica_world_size = 5
    # each accessor waits on the quorum future before answering
    self.assertEqual(manager.num_participants(), 5)
    self.assertEqual(quorum_future.result.call_count, 1)
    self.assertEqual(manager.participating_rank(), 1)
    self.assertEqual(quorum_future.result.call_count, 2)

    # pyre-ignore[16]: _pg is mocked
    manager._pg.allreduce.return_value = _DummyWork(None)

    self.assertTrue(manager.is_participating())

    for dtype in (torch.float16, torch.bfloat16, torch.float32, torch.long):
        orig = torch.tensor([10], dtype=dtype)

        if torch.is_floating_point(orig):
            tensor = orig.clone()
            manager.allreduce(tensor).wait()
            torch.testing.assert_close(tensor, orig / 5)

            tensor = orig.clone()
            manager.allreduce(tensor, reduce_op=ReduceOp.AVG).wait()
            torch.testing.assert_close(tensor, orig / 5)

        for reduce_op in [
            ReduceOp.SUM,
            ReduceOp.MAX,
            ReduceOp.MIN,
            ReduceOp.PRODUCT,
        ]:
            tensor = orig.clone()
            manager.allreduce(tensor, reduce_op=reduce_op).wait()
            torch.testing.assert_close(tensor, orig)

    # check healing numerics
    manager._healing = True
    self.assertFalse(manager.is_participating())
    tensor = torch.tensor([1.0])
    work = manager.allreduce(tensor)
    work.wait()
    torch.testing.assert_close(tensor, torch.tensor([0.0]))
|
|
623
|
+
|
|
624
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_happy_timeouts(self, client_mock: MagicMock) -> None:
    """Explicit timeouts are forwarded to the client's _quorum and should_commit calls."""
    manager = self._create_manager(use_async_quorum=False)
    # NOTE: should_commit is left as the autospec mock here (not replaced by
    # mock_should_commit), so its call kwargs can be inspected below.

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    manager.start_quorum(timeout=timedelta(seconds=12))
    self.assertEqual(
        client_mock()._quorum.call_args.kwargs["timeout"], timedelta(seconds=12)
    )

    self.assertTrue(manager.should_commit(timeout=timedelta(seconds=23)))
    self.assertEqual(
        client_mock().should_commit.call_args.kwargs["timeout"],
        timedelta(seconds=23),
    )
|
|
651
|
+
|
|
652
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_skip_init(self, client_mock: MagicMock) -> None:
    """The manager's _init_sync flag is forwarded to the client's _quorum call.

    Flipping manager._init_sync between quorums changes the forwarded value.
    """
    manager = self._create_manager(
        use_async_quorum=False,
        init_sync=False,
    )

    self.assertFalse(manager._init_sync)

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 1
    quorum.max_replica_rank = 1
    quorum.max_world_size = 2
    quorum.heal = False

    client_mock()._quorum.return_value = quorum

    manager.start_quorum()
    self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], False)

    manager._init_sync = True
    manager.start_quorum()
    self.assertEqual(client_mock()._quorum.call_args.kwargs["init_sync"], True)
|
|
680
|
+
|
|
681
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_checkpoint_errors(self, client_mock: MagicMock) -> None:
    """Checkpoint transport failures are recorded as manager errors.

    With heal=True the recv failure is recorded. Once
    recover_dst_replica_ranks is set (presumably making this replica a
    recovery source), the send failure is recorded instead. In both cases
    should_commit() is False.
    """
    manager = self._create_manager(use_async_quorum=True)
    client_mock().should_commit = MagicMock(return_value=False)

    # both directions of the checkpoint transport fail
    transport = MagicMock(spec=CheckpointTransport)
    transport.send_checkpoint.side_effect = RuntimeError("send failure")
    transport.recv_checkpoint.side_effect = RuntimeError("recv failure")
    manager._checkpoint_transport = transport

    quorum = QuorumResult()
    quorum.quorum_id = 123
    quorum.replica_rank = 1
    quorum.replica_world_size = 2
    quorum.recover_src_manager_address = "manager address"
    quorum.recover_src_replica_rank = 0
    quorum.store_address = f"localhost:{self.store.port}"
    quorum.max_step = 20
    quorum.max_replica_rank = None
    quorum.max_world_size = 2
    quorum.heal = True

    client_mock()._quorum.return_value = quorum

    manager.start_quorum()
    manager.wait_quorum()
    self.assertFalse(manager.should_commit())

    error = manager.errored()
    assert error is not None
    with self.assertRaisesRegex(RuntimeError, "recv failure"):
        raise error.original_exception

    quorum.recover_dst_replica_ranks = [0]
    manager.start_quorum()
    manager.wait_quorum()
    self.assertFalse(manager.should_commit())

    error = manager.errored()
    assert error is not None
    with self.assertRaisesRegex(RuntimeError, "send failure"):
        raise error.original_exception
|
|
723
|
+
|
|
724
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_quorum_configure_errors(self, client_mock: MagicMock) -> None:
    """A failing process-group ``configure`` during quorum is recorded as the manager error."""
    manager = self._create_manager(use_async_quorum=True)
    client_mock().should_commit = MagicMock(return_value=False)

    # The process group's configure() call is rigged to raise.
    # pyre-ignore[16]: mock
    manager._pg.configure.side_effect = RuntimeError("configure failure")

    result = QuorumResult()
    result.quorum_id = 123
    result.replica_rank = 1
    result.replica_world_size = 2
    result.recover_src_manager_address = "manager address"
    result.recover_src_replica_rank = 0
    result.store_address = f"localhost:{self.store.port}"
    result.max_step = 20
    result.max_replica_rank = None
    result.max_world_size = 2

    client_mock()._quorum.return_value = result

    # Run a full quorum round; the commit must then be rejected.
    manager.start_quorum()
    manager.wait_quorum()
    self.assertFalse(manager.should_commit())

    # The stored error wraps the exception raised by configure().
    failure = manager.errored()
    assert failure is not None
    with self.assertRaisesRegex(RuntimeError, "configure failure"):
        raise failure.original_exception
|
|
753
|
+
|
|
754
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_max_retries(self, client_mock: MagicMock) -> None:
    """Commit failures beyond ``max_retries`` raise; a successful commit resets the count."""
    manager = self._create_manager(max_retries=2)

    # A non-healing quorum in which this replica is rank 1 of 2.
    result = QuorumResult()
    result.quorum_id = 123
    result.replica_rank = 1
    result.replica_world_size = 2
    result.recover_src_manager_address = "manager address"
    result.store_address = f"localhost:{self.store.port}"
    result.max_step = 1
    result.max_replica_rank = 1
    result.max_world_size = 2
    result.heal = False
    client_mock()._quorum.return_value = result

    # Every commit vote is rejected, so each should_commit() counts as a failure.
    client_mock().should_commit = MagicMock(return_value=False)

    manager.start_quorum()

    # Failures within the retry budget return False and bump the counter.
    for expected_failures in (1, 2):
        self.assertFalse(manager.should_commit())
        self.assertEqual(manager._commit_failures, expected_failures)

    # The third consecutive failure exceeds max_retries=2 and raises.
    with self.assertRaisesRegex(RuntimeError, "exceeding max_retries=2"):
        manager.should_commit()
    self.assertEqual(manager._commit_failures, 3)

    # Sitting right at the threshold, a successful commit clears the counter.
    manager._commit_failures = 2
    client_mock().should_commit = MagicMock(return_value=True)
    self.assertTrue(manager.should_commit())
    self.assertEqual(manager._commit_failures, 0)
|
|
800
|
+
|
|
801
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict_lock_allow_disallow(self, client_mock: MagicMock) -> None:
    """Check that disallow/allow_state_dict_read toggle the flag and write lock idempotently."""
    manager = self._create_manager()

    def assert_read_allowed(expected: bool) -> None:
        # The read flag equals ``expected`` and the write lock is held
        # exactly when reads are not allowed.
        self.assertEqual(bool(manager._is_state_dict_read_allowed), expected)
        self.assertEqual(bool(manager._state_dict_lock.w_locked()), not expected)

    # Reads start out permitted.
    self.assertTrue(manager._is_state_dict_read_allowed)

    # Disallowing takes the write lock; the repeated call leaves state unchanged.
    for _ in range(2):
        manager.disallow_state_dict_read()
        assert_read_allowed(False)

    # Allowing releases the write lock; the repeated call leaves state unchanged.
    for _ in range(2):
        manager.allow_state_dict_read()
        assert_read_allowed(True)
|
|
828
|
+
|
|
829
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_state_dict_lock_concurrent_access(self, client_mock: MagicMock) -> None:
    """Test that a concurrent reader observes the state-dict read permission.

    A worker thread inspects ``_is_state_dict_read_allowed`` only after the
    main thread has called ``disallow_state_dict_read`` (first round) or
    ``allow_state_dict_read`` (second round) and then released it. Every
    handshake step is asserted. A worker that never starts, is never
    released, or never finishes fails the test instead of letting the
    "access denied" assertion pass vacuously.

    NOTE(review): the worker only reads the flag; it never tries to acquire
    ``_state_dict_lock`` itself, so lock blocking is not exercised here.
    """
    manager: Manager = self._create_manager()
    # Successful waits return immediately. This bound only acts as a failure
    # deadline when something is stuck, so it is generous to avoid CI flakes.
    timeout_s = 5.0

    # Handshake events between the main thread and the worker.
    access_attempted: threading.Event = threading.Event()
    can_proceed: threading.Event = threading.Event()
    # "released" records whether the worker's wait actually saw can_proceed
    # set, rather than timing out and checking the flag anyway.
    access_result: dict[str, bool] = {"succeeded": False, "released": False}

    def try_access_state_dict() -> None:
        # The closure only reads these names and mutates ``access_result`` in
        # place, so no ``nonlocal`` declaration is needed.
        access_attempted.set()
        access_result["released"] = can_proceed.wait(timeout=timeout_s)

        # Try to access the state dict
        if manager._is_state_dict_read_allowed:
            access_result["succeeded"] = True

    def start_worker() -> threading.Thread:
        # Reset the handshake state and launch a fresh worker thread.
        access_attempted.clear()
        can_proceed.clear()
        access_result["released"] = False
        worker = threading.Thread(target=try_access_state_dict, daemon=True)
        worker.start()
        return worker

    def release_and_join(worker: threading.Thread) -> None:
        # Wait for the worker to park, release it, and require that it ran
        # to completion after observing the release signal.
        self.assertTrue(
            access_attempted.wait(timeout=timeout_s), "worker thread never started"
        )
        can_proceed.set()
        worker.join(timeout=timeout_s)
        self.assertFalse(worker.is_alive(), "worker thread did not finish")
        self.assertTrue(access_result["released"], "worker was not released")

    # Round 1: reads are disallowed while the worker checks.
    thread = start_worker()
    manager.disallow_state_dict_read()
    self.assertFalse(manager._is_state_dict_read_allowed)
    release_and_join(thread)

    # The worker must have observed reads as disallowed.
    self.assertFalse(access_result["succeeded"])

    # Round 2: reads are allowed again while a new worker checks.
    thread = start_worker()
    manager.allow_state_dict_read()
    self.assertTrue(manager._is_state_dict_read_allowed)
    release_and_join(thread)

    # The worker must now have observed reads as allowed.
    self.assertTrue(access_result["succeeded"])
|
|
890
|
+
|
|
891
|
+
@patch("torchft.manager.ManagerClient", autospec=True)
def test_manager_state_dict_with_lock(self, client_mock: MagicMock) -> None:
    """Test that _manager_state_dict properly uses the read lock.

    The real RWLock is swapped for an autospec mock with ``patch.object``, so
    the original lock is restored even if an assertion fails or
    ``_manager_state_dict`` raises.
    """
    manager = self._create_manager()
    original_lock = manager._state_dict_lock

    # Autospec mock of RWLock. The object returned by r_lock() is a MagicMock
    # whose default __enter__/__exit__ record calls, and __exit__ returns a
    # falsy value so no exception is suppressed.
    mock_lock = create_autospec(RWLock)
    read_ctx = mock_lock.r_lock.return_value

    with patch.object(manager, "_state_dict_lock", mock_lock):
        manager._manager_state_dict()

        # The read lock must be requested once, and its context must be
        # entered and exited exactly once.
        mock_lock.r_lock.assert_called_once()
        read_ctx.__enter__.assert_called_once()
        read_ctx.__exit__.assert_called_once()

    # patch.object puts the real lock back on exit.
    self.assertIs(manager._state_dict_lock, original_lock)
|