torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/collectives_test.py
ADDED
@@ -0,0 +1,212 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import unittest
from typing import Callable
from unittest import skipUnless, TestCase

import torch
import torch.distributed as dist
from parameterized import parameterized
from torch import cuda
from torch.distributed import ReduceOp, ReduceScatterOptions

from torchft import _test_utils
from torchft.process_group import ProcessGroup
from torchft.process_group_test import MultiPgBaseTest

try:
    # pyre-ignore[21]: Could not find a module corresponding to import `triton`
    import triton
except ImportError:
    pass
else:
    from torchft.collectives import (
        allocate_reduce_scatter_output,
        allreduce_quantized,
        get_padded_sizes,
        reduce_scatter_quantized,
    )

    def _check_result_tolerance(
        actual: torch.Tensor, expected: torch.Tensor, tolerance: float
    ) -> None:
        diff = torch.abs(
            (expected - actual).div(expected.to(torch.float32) + 0.0000001)
        )
        mean_diff = diff.mean().item()

        if mean_diff > tolerance:
            print(f"Diff: {diff=}\n{expected=}\n{actual=}")
            raise AssertionError(f"Results not within tolerance {tolerance}")

    @skipUnless(
        torch.cuda.is_available() and torch.cuda.device_count() >= 2,
        "2 CUDA devices are required for this test",
    )
    class QuantizedAllReduceTest(MultiPgBaseTest):
        BACKEND = "nccl"
        WORLD_SIZE = 2

        def _run_parallel_collectives(
            self, collective: Callable[[ProcessGroup, int, str], None]
        ) -> None:
            futures = []
            for rank in range(self.WORLD_SIZE):
                pg = self.pg_pool[rank]
                device = f"cuda:{rank}"
                fut = self.executor.submit(collective, pg, rank, device)
                futures.append(fut)

            self._collect(futures)

        def _run_all_reduce_collective(
            self,
            pg: ProcessGroup,
            device: str,
            tensors_num: int,
            tensor_size: int,
            multiplier: float,
            tolerance: float,
            reduce_op: ReduceOp,
            dtype: torch.dtype,
        ) -> None:
            cuda.set_device(device)
            inp = (
                torch.rand(
                    tensors_num * tensor_size,
                    dtype=dtype,
                    device=device,
                )
                * multiplier
            )
            for split in _test_utils.gen_splits(inp, tensor_size):
                actual = inp.clone()
                expected = inp.clone()
                tensors = [
                    i.view(*s)
                    for s, i in zip(
                        split,
                        torch.split(actual, tensor_size),
                    )
                ]

                work = allreduce_quantized(tensors, reduce_op, pg)
                work.wait()

                work = pg.allreduce([expected], reduce_op)
                work.get_future().wait()

                _check_result_tolerance(actual, expected, tolerance)

        def _run_reduce_scatter_collective(
            self,
            pg: ProcessGroup,
            device: str,
            tensors_num: int,
            tensor_size: int,
            multiplier: float,
            tolerance: float,
            reduce_op: ReduceOp,
            dtype: torch.dtype,
        ) -> None:
            cuda.set_device(device)
            inp = (
                torch.rand(
                    tensors_num * tensor_size,
                    dtype=dtype,
                    device=device,
                )
                * multiplier
            )
            world_size = pg.size()
            for split in _test_utils.gen_splits(inp, tensor_size):
                actual = inp.clone()
                tensors = [
                    i.view(*s)
                    for s, i in zip(
                        split,
                        torch.split(actual, tensor_size),
                    )
                ]

                actual_output, _ = allocate_reduce_scatter_output(
                    tensors,
                    world_size,
                )

                opts = ReduceScatterOptions()
                opts.reduceOp = reduce_op

                work = reduce_scatter_quantized(actual_output, tensors, opts, pg)
                work.get_future().wait()

                padded_sizes = get_padded_sizes(tensors, world_size)
                padded_numel = sum(s.numel() for s in padded_sizes)

                padded_input = torch.empty(padded_numel, dtype=dtype, device=device)
                torch._chunk_cat(
                    tensors, dim=0, num_chunks=world_size, out=padded_input
                )

                expected_output = torch.empty(
                    padded_numel // world_size, dtype=dtype, device=device
                )

                work = pg.reduce_scatter([expected_output], [[padded_input]], opts)
                work.get_future().wait()

                _check_result_tolerance(actual_output, expected_output, tolerance)

        END_TO_END_CONFIGS: list[tuple[int, float, ReduceOp, torch.dtype]] = [
            (ts, m, o, t)
            for ts in [256, 1024, 2048]
            for m in [1.0, 100.0, 1000.0]
            for o in [ReduceOp.AVG, ReduceOp.SUM]
            for t in [torch.float32, torch.float16, torch.bfloat16]
        ]

        @parameterized.expand(END_TO_END_CONFIGS)
        def test_all_reduce_collective(
            self,
            tensor_size: int,
            multiplier: float,
            reduce_op: ReduceOp,
            dtype: torch.dtype,
        ) -> None:
            self._run_parallel_collectives(
                lambda pg, _, device: self._run_all_reduce_collective(
                    pg,
                    device,
                    2,
                    tensor_size,
                    multiplier,
                    0.04,
                    reduce_op,
                    dtype,
                )
            )

        @parameterized.expand(END_TO_END_CONFIGS)
        def test_reduce_scatter_collective(
            self,
            tensor_size: int,
            multiplier: float,
            reduce_op: ReduceOp,
            dtype: torch.dtype,
        ) -> None:
            self._run_parallel_collectives(
                lambda pg, _, device: self._run_reduce_scatter_collective(
                    pg,
                    device,
                    2,
                    tensor_size,
                    multiplier,
                    0.05,
                    reduce_op,
                    dtype,
                )
            )

torchft/coordination.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Coordination (Low Level API)
============================

.. warning::
    As torchft is still in development, the APIs in this module are subject to change.

This module exposes low level coordination APIs to allow you to build your own
custom fault tolerance algorithms on top of torchft.

If you're looking for a more complete solution, please use the other modules in
torchft.

This provides direct access to the Lighthouse and Manager servers and clients.
"""

from torchft._torchft import (
    LighthouseClient,
    LighthouseServer,
    ManagerClient,
    ManagerServer,
    Quorum,
    QuorumMember,
)

__all__ = [
    "LighthouseClient",
    "LighthouseServer",
    "ManagerServer",
    "ManagerClient",
    "Quorum",
    "QuorumMember",
]

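The module above only re-exports the Rust-backed coordination bindings. As a rough, hedged sketch of how they are typically used: the bind address, the min_replicas value, and the address()/shutdown() calls below reflect common torchft examples and may differ between nightly builds.

from torchft.coordination import LighthouseServer

# Start a lighthouse that replica-group Managers connect to in order to form a quorum.
lighthouse = LighthouseServer(bind="[::]:29510", min_replicas=1)
print("lighthouse listening at", lighthouse.address())

# ... launch Manager processes pointed at lighthouse.address() ...

lighthouse.shutdown()
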
torchft/coordination_test.py
ADDED
@@ -0,0 +1,29 @@
import inspect
from unittest import TestCase

from torchft.coordination import (
    LighthouseClient,
    LighthouseServer,
    ManagerClient,
    ManagerServer,
    Quorum,
    QuorumMember,
)


class TestCoordination(TestCase):
    def test_coordination_docs(self) -> None:
        classes = [
            ManagerClient,
            ManagerServer,
            LighthouseServer,
            LighthouseClient,
            Quorum,
            QuorumMember,
        ]
        for cls in classes:
            self.assertIn("Args:", str(cls.__doc__), cls)
            for name, method in inspect.getmembers(cls, predicate=inspect.ismethod):
                if name.startswith("_"):
                    continue
                self.assertIn("Args:", str(cls.__doc__), cls)

torchft/data.py
ADDED
@@ -0,0 +1,77 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Data
====

This module provides helper classes to implement fault tolerant data loaders.

We recommend using torchdata's StatefulDataLoader to checkpoint each replica's
dataloader frequently to avoid duplicate batches.
"""

from typing import Optional

import torch.distributed as dist
from torch.utils import data


# pyre-fixme[24]: expected generic parameter
class DistributedSampler(data.distributed.DistributedSampler):
    """
    DistributedSampler extends the standard PyTorch DistributedSampler with a
    `num_replica_groups` argument that is used to shard the data across the
    fault tolerant replica groups.

    torchft doesn't know how many replica groups there are ahead of time, so
    this needs to be set to the maximum number.

    This sampler is inherently lossy when used with torchft. torchft
    occasionally drops batches on rejoining, and if a replica group is down,
    that group's examples will never be used. This can lead to imbalances when
    using a small dataset.

    This will shard the input dataset into ``num_replicas*num_replica_groups``
    shards.

    Each shard rank is calculated via: ``rank + num_replicas*replica_rank``

    num_replicas and replica_rank must be the same on all workers.
    """

    def __init__(
        self,
        dataset: data.Dataset,
        replica_rank: int,
        num_replica_groups: int,
        group_rank: Optional[int] = None,
        num_replicas: Optional[int] = None,
        **kwargs: object,
    ) -> None:
        """
        Args:
            dataset: the dataset to use
            replica_rank: the group ID (0-num_replica_groups) to use for this shard of data.
            num_replica_groups: the max number of global replica groups
            group_rank: the local group rank
            num_replicas: the local group world size
        """
        if group_rank is None:
            group_rank = dist.get_rank()
        if num_replicas is None:
            num_replicas = dist.get_world_size()

        self.global_rank: int = group_rank + num_replicas * replica_rank
        self.global_world_size: int = num_replicas * num_replica_groups

        super().__init__(
            dataset,
            rank=self.global_rank,
            num_replicas=self.global_world_size,
            # pyre-fixme[6]: got object
            **kwargs,
        )

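For reference, a minimal, hypothetical sketch of feeding this sampler to a standard DataLoader, illustrating the shard-rank formula from the docstring; the dataset, ranks, and batch size below are placeholders and not part of the package.

import torch
from torch.utils.data import DataLoader, TensorDataset

from torchft.data import DistributedSampler

dataset = TensorDataset(torch.arange(1000).unsqueeze(1))

# Worker 3 of a 4-worker replica group, in replica group 1 of at most 2 groups:
# shard rank = 3 + 4 * 1 = 7 out of 4 * 2 = 8 shards.
sampler = DistributedSampler(
    dataset,
    replica_rank=1,
    num_replica_groups=2,
    group_rank=3,
    num_replicas=4,
)
loader = DataLoader(dataset, batch_size=32, sampler=sampler)
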
torchft/data_test.py
ADDED
@@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import TestCase

from torch.utils.data import Dataset

from torchft.data import DistributedSampler


class DummyDataset(Dataset):
    def __init__(self, length: int) -> None:
        self.length = length

    def __len__(self) -> int:
        return self.length

    def __getitem__(self, idx: int) -> int:
        return idx


class TestData(TestCase):
    def test_distributed_sampler(self) -> None:
        dataset = DummyDataset(1000)
        sampler = DistributedSampler(
            dataset,
            replica_rank=1,
            num_replica_groups=2,
            group_rank=3,
            num_replicas=4,
        )
        self.assertEqual(sampler.global_rank, 3 + 1 * 4)
        self.assertEqual(sampler.global_world_size, 2 * 4)

        sampler_iter = iter(sampler)
        self.assertEqual(next(sampler_iter), 500)

torchft/ddp.py
ADDED
@@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

"""
Distributed Data Parallel
==========================

This module implements a DistributedDataParallel wrapper that works with the
Manager to provide fault tolerance.
"""

import os
import sys
from typing import cast, Optional, TYPE_CHECKING
from unittest.mock import patch

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.algorithms.join import Joinable
from torch.nn import parallel

from torchft.process_group import ProcessGroup, ProcessGroupDummy, ProcessGroupGloo

if TYPE_CHECKING:
    from torchft.manager import _ManagedFuture, Manager


class DistributedDataParallel(parallel.DistributedDataParallel):
    """
    This is a patched DistributedDataParallel implementation that makes it
    compatible with torchft.

    Important notes:

    * This requires states to be synced on step 0 using an external mechanism
      rather than an internal broadcast (torchft.Manager will do this).
    * Using non-basic features of the DDP may cause your model to catch fire as
      they haven't been tested with torchft.
    * This doesn't do any sanity checks such as verifying parameter sizes are
      the same across workers.
    """

    def __init__(self, manager: "Manager", module: nn.Module, **kwargs: object) -> None:
        # use a dummy PG to soak up the init all reduce, actual comms will go
        # through the comm_hook.
        pg = ProcessGroupDummy(0, 1)

        super().__init__(
            module,
            process_group=pg,
            # HACK: This forces the reducer to never rebuild buckets.
            # The reducer normally rebuilds the buckets after the first training
            # step which can improve performance but is incompatible with
            # torchft as it will cause the buckets to diverge for recovering
            # replicas.
            find_unused_parameters=True,
            # pyre-fixme[6]: got object
            **kwargs,
        )

        self.register_comm_hook(manager, self._comm_hook)

    @staticmethod
    def _comm_hook(
        state: "Manager", bucket: dist.GradBucket
    ) -> torch.futures.Future[torch.Tensor]:
        work = state.allreduce(bucket.buffer())
        work.wait()
        fut = work.get_future()

        # We need to return the underlying future here otherwise
        # this can hang
        fut = cast("_ManagedFuture[torch.Tensor]", fut)
        assert fut._fut
        return fut._fut


class PureDistributedDataParallel(nn.Module):
    """
    A pure Python reimplementation of the DDP wrapper.

    We recommend using DistributedDataParallel instead of this class.

    This calls one allreduce per gradient tensor and doesn't use a reducer. This
    may be very slow for real models.
    """

    def __init__(self, manager: "Manager", module: nn.Module) -> None:
        super().__init__()

        self.module = module

        def post_grad_hook(p: torch.Tensor) -> None:
            if p.grad is not None:
                manager.allreduce(p.grad)

        for p in module.parameters():
            p.register_post_accumulate_grad_hook(post_grad_hook)

    def forward(self, *args: object) -> object:
        return self.module(*args)

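For context, a hedged end-to-end sketch of how this wrapper is typically combined with the Manager and OptimizerWrapper elsewhere in the package. The Manager keyword arguments and the single-process defaults shown here are assumptions drawn from common torchft examples, not from this diff; real deployments also need the lighthouse/manager environment configuration, which is omitted.

import torch
from torch import nn, optim

from torchft.ddp import DistributedDataParallel
from torchft.manager import Manager
from torchft.optim import OptimizerWrapper
from torchft.process_group import ProcessGroupGloo

model = nn.Linear(3, 4)

# The Manager owns quorum membership, recovery, and allreduce routing.
# Constructor arguments here are illustrative assumptions.
manager = Manager(
    pg=ProcessGroupGloo(),
    load_state_dict=lambda sd: model.load_state_dict(sd),
    state_dict=lambda: model.state_dict(),
    min_replica_size=1,
)

model = DistributedDataParallel(manager, model)
optimizer = OptimizerWrapper(manager, optim.AdamW(model.parameters()))

# One training step: zero_grad() joins the quorum for the step, and step()
# only applies the update if the manager reports the step can commit.
optimizer.zero_grad()
out = model(torch.rand(8, 3))
out.mean().backward()
optimizer.step()
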
torchft/ddp_test.py
ADDED
@@ -0,0 +1,68 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from unittest import TestCase
from unittest.mock import create_autospec

import torch
import torch.distributed as dist
from torch import nn
from torch.distributed.distributed_c10d import Work
from torch.futures import Future

from torchft.ddp import DistributedDataParallel, PureDistributedDataParallel
from torchft.manager import _ManagedWork, Manager
from torchft.process_group import ProcessGroupBabyGloo, ProcessGroupGloo
from torchft.work import _DummyWork


class TestDDP(TestCase):
    def test_pure_ddp(self) -> None:
        manager = create_autospec(Manager)

        m = nn.Linear(3, 4)
        m = PureDistributedDataParallel(manager, m)

        inp = torch.rand(2, 3)
        out = m(inp)
        loss = out.mean()
        loss.backward()

        for p in m.parameters():
            self.assertIsNotNone(p.grad)

        self.assertEqual(manager.allreduce.call_count, len(list(m.parameters())))

    def test_ddp(self) -> None:
        manager = create_autospec(Manager)

        call_count = 0

        # pyre-ignore[53]: Captured variable `manager` is not annotated.
        def allreduce(
            tensor: torch.Tensor,
        ) -> Work:
            nonlocal call_count

            call_count += 1

            work = _DummyWork(tensor)
            return _ManagedWork(manager, work, tensor)

        manager.allreduce = allreduce

        m = nn.Linear(3, 4)
        m = DistributedDataParallel(manager, m)

        inp = torch.rand(2, 3)
        out = m(inp)
        loss = out.mean()
        loss.backward()

        for p in m.parameters():
            self.assertIsNotNone(p.grad)

        self.assertGreaterEqual(call_count, 1)