torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
@@ -0,0 +1,161 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import threading
import traceback
from concurrent.futures import as_completed, ThreadPoolExecutor
from datetime import timedelta
from typing import Callable
from unittest import TestCase

import torch
import torch.distributed as dist
from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor

from torchft.checkpointing.transport import CheckpointTransport

TIMEOUT_REGEX = r".*(Timed out|timed out|timeout|time out).*"


def assertStateDictEqual(
    self: TestCase, a: dict[str, object], b: dict[str, object]
) -> None:
    for k, v1 in a.items():
        v2 = b[k]
        if isinstance(v1, DTensor) and isinstance(v2, DTensor):
            torch.testing.assert_close(v1._local_tensor.cpu(), v2._local_tensor.cpu())
            self.assertEqual(v1._spec, v2._spec)
        elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
            torch.testing.assert_close(v1.cpu(), v2.cpu())
        else:
            self.assertEqual(v1, v2)


def make_state_dict(device: torch.device) -> dict[str, object]:
    device_mesh = DeviceMesh("cpu", 1)
    tensor = torch.tensor([5, 6, 7])
    dtensor: DTensor = distribute_tensor(tensor, device_mesh, [])

    return {
        "rank": torch.tensor([1, 2, 3], device=device),
        # "strided": torch.tensor([10], device=device)[1::2],
        "str": "str",
        "int": 1234,
        "dtensor": dtensor,
    }


def run_multi_recovery_test(
    self: TestCase,
    init_transport: Callable[[int, int], CheckpointTransport[dict[str, object]]],
    device: torch.device,
) -> None:
    """
    Runs multi-node recovery tests for a given transport factory.

    This tests send/recv in a 3-node setup, with all workers and with only
    some workers recovering, and also tests timeout behavior.
    """
    WORLD_SIZE: int = 3

    # barrier is used to simulate quorum/allreduce barriers
    barrier: threading.Barrier = threading.Barrier(WORLD_SIZE, timeout=10)
    metadata: str = ""

    dist.init_process_group(
        backend="gloo", rank=0, world_size=1, store=dist.HashStore()
    )

    def run(rank: int) -> CheckpointTransport[dict[str, object]]:
        transport = init_transport(rank, WORLD_SIZE)

        if rank == 0:
            nonlocal metadata
            metadata = transport.metadata()

        barrier.wait()

        state_dict: dict[str, object] = make_state_dict(device)

        # 3 node recovery
        if rank == 0:
            transport.send_checkpoint(
                dst_ranks=[1, 2],
                step=1,
                state_dict=state_dict,
                timeout=timedelta(seconds=10),
            )
        else:
            got = transport.recv_checkpoint(
                src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=10)
            )
            assertStateDictEqual(self, got, state_dict)

        barrier.wait()
        transport.disallow_checkpoint()

        # 2 node recovery
        if rank == 0:
            transport.send_checkpoint(
                dst_ranks=[2],
                step=2,
                state_dict=state_dict,
                timeout=timedelta(seconds=10),
            )
        elif rank == 2:
            got = transport.recv_checkpoint(
                src_rank=0, metadata=metadata, step=2, timeout=timedelta(seconds=10)
            )
            assertStateDictEqual(self, got, state_dict)

        barrier.wait()
        transport.disallow_checkpoint()

        # timeout test
        if rank == 2:
            with self.assertRaisesRegex(Exception, TIMEOUT_REGEX):
                transport.recv_checkpoint(
                    src_rank=0,
                    metadata=metadata,
                    step=3,
                    timeout=timedelta(milliseconds=10),
                )

            # Make sure send completes quickly.
            # If the transport is async (such as with HTTP) this may just return
            # immediately.
            try:
                transport.send_checkpoint(
                    dst_ranks=[0],
                    step=4,
                    state_dict=state_dict,
                    timeout=timedelta(seconds=10),
                )
            except Exception:
                with self.assertRaisesRegex(Exception, TIMEOUT_REGEX):
                    raise

        return transport

    with ThreadPoolExecutor(max_workers=WORLD_SIZE) as executor:
        results = []
        for i in range(WORLD_SIZE):
            results.append(executor.submit(run, i))

        transports = []

        try:
            for fut in as_completed(results, timeout=10.0):
                transports.append(fut.result())
        except Exception as e:
            print(e)
            traceback.print_exc()
            raise

        for transport in transports:
            transport.shutdown()

    dist.destroy_process_group()
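To show how the harness above is meant to be driven, here is an illustrative sketch (not part of the package): a toy in-memory transport that duck-types only the methods run_multi_recovery_test exercises (metadata, send_checkpoint, recv_checkpoint, disallow_checkpoint, shutdown), plus a test case wiring it in. The names InMemoryTransport and InMemoryTransportTest are hypothetical; the real transports ship in http_transport.py and pg_transport.py, and the sketch reuses the imports and helpers defined above.

class InMemoryTransport:
    """Toy transport: shares checkpoints via a class-level dict keyed by step."""

    _store: dict[int, dict[str, object]] = {}
    _cond: threading.Condition = threading.Condition()

    def metadata(self) -> str:
        return "in-memory"

    def disallow_checkpoint(self) -> None:
        pass

    def send_checkpoint(
        self,
        dst_ranks: list[int],
        step: int,
        state_dict: dict[str, object],
        timeout: timedelta,
    ) -> None:
        # Publishing is instantaneous, so sends never block.
        with self._cond:
            self._store[step] = state_dict
            self._cond.notify_all()

    def recv_checkpoint(
        self,
        src_rank: int,
        metadata: str,
        step: int,
        timeout: timedelta,
    ) -> dict[str, object]:
        # Block until the sender publishes this step, or time out.
        with self._cond:
            if not self._cond.wait_for(
                lambda: step in self._store, timeout=timeout.total_seconds()
            ):
                raise TimeoutError(f"recv_checkpoint timed out waiting for step {step}")
            return self._store[step]

    def shutdown(self, wait: bool = True) -> None:
        pass


class InMemoryTransportTest(TestCase):
    def test_multi_recovery(self) -> None:
        run_multi_recovery_test(
            self,
            lambda rank, world_size: InMemoryTransport(),
            device=torch.device("cpu"),
        )
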
torchft/collectives.py
ADDED
@@ -0,0 +1,415 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import math
from typing import TYPE_CHECKING

import torch

# pyre-ignore[21]: Could not find a module corresponding to import `triton`
import triton
from torch import cuda
from torch.distributed import ReduceOp
from torch.distributed.distributed_c10d import (
    AllgatherOptions,
    AllreduceOptions,
    AllToAllOptions,
    ReduceScatterOptions,
    Work,
)
from torch.futures import Future

if TYPE_CHECKING:
    from torchft.process_group import ProcessGroup

from torchft.quantization import (
    fused_dequantize_from_fp8,
    fused_quantize_into_fp8,
    fused_reduce_fp8,
)


def _to_alltoall_options(
    opts: AllreduceOptions | ReduceScatterOptions,
) -> AllToAllOptions:
    alltoall_opts = AllToAllOptions()
    alltoall_opts.timeout = opts.timeout
    return alltoall_opts


def _to_allgather_options(
    opts: AllreduceOptions | ReduceScatterOptions,
) -> AllgatherOptions:
    allgather_opts = AllgatherOptions()
    allgather_opts.timeout = opts.timeout
    return allgather_opts


def get_padded_sizes(
    tensors: list[torch.Tensor],
    world_size: int,
) -> list[torch.Size]:
    """
    Calculate padded sizes for tensors to ensure they can be evenly
    divided across ranks.

    This function computes padded tensor sizes by rounding up the
    first dimension of each tensor to be a multiple of the world_size.
    This ensures that when tensors are split across ranks
    in distributed operations, each process receives an equal
    number of elements.

    Args:
        tensors: List of tensors whose sizes need to be padded
        world_size: Number of ranks in the distributed setup

    Returns:
        List of torch.Size objects with the first dimension padded
        to be a multiple of world_size

    Note:
        For 1D tensors, they are treated as 2D tensors with a
        second dimension of 1
    """
    padded_sizes = []
    for tensor in tensors:
        size = tensor.size()
        if len(size) == 1:
            size = (size[0], 1)
        padded_m = math.ceil(size[0] / world_size) * world_size
        padded_sizes.append(torch.Size((padded_m, *size[1:])))
    return padded_sizes

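# Illustrative sketch (not part of the package file): a worked example of
# get_padded_sizes. With two ranks, a (5, 4) tensor is padded to (6, 4), and a
# 1D tensor of shape (3,) is treated as (3, 1) and padded to (4, 1).
def _example_get_padded_sizes() -> None:
    sizes = get_padded_sizes([torch.empty(5, 4), torch.empty(3)], world_size=2)
    assert sizes == [torch.Size((6, 4)), torch.Size((4, 1))]
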
def allocate_reduce_scatter_output(
    tensors: list[torch.Tensor],
    world_size: int,
) -> tuple[torch.Tensor, list[torch.Size]]:
    """
    Allocate a tensor for the output of a reduce-scatter operation.

    This function creates a single contiguous tensor to hold the results of a
    reduce-scatter operation across multiple ranks. It ensures that the tensor
    is properly sized and shaped to accommodate the results, where each rank
    will receive a portion of the reduced data.

    Args:
        tensors: List of input tensors for the reduce-scatter operation.
            All tensors must be on the same device and have the same
            data type.
        world_size: Number of ranks in the distributed setup

    Returns:
        A tuple containing:
        - A single contiguous tensor allocated for the reduce-scatter output
        - A list of padded sizes for the input tensors that were split across
          ranks

    Raises:
        AssertionError: If the input tensors are not all on the same device or
            do not all have the same data type
    """
    device = tensors[0].device
    dtype = tensors[0].dtype
    for i in range(1, len(tensors)):
        assert (
            tensors[i].device == tensors[i - 1].device
        ), "All inputs must be on the same device"
        assert (
            tensors[i].dtype == tensors[i - 1].dtype
        ), "All inputs must have the same dtype"

    padded_sizes = get_padded_sizes(tensors, world_size)

    chunks = []
    numels = [size.numel() // world_size for size in padded_sizes]
    tensor = torch.empty(
        (sum(numels),),
        device=device,
        dtype=dtype,
    )
    for split, padded_size in zip(torch.split(tensor, numels), padded_sizes):
        chunks.append(split.view(padded_size[0] // world_size, *padded_size[1:]))
    return tensor, padded_sizes

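# Illustrative sketch (not part of the package file): continuing the example
# above with two ranks, the reduce-scatter output buffer holds each rank's
# share of the padded inputs: 6*4/2 + 4*1/2 = 12 + 2 = 14 elements.
def _example_allocate_reduce_scatter_output() -> None:
    output, padded_sizes = allocate_reduce_scatter_output(
        [torch.empty(5, 4), torch.empty(3)], world_size=2
    )
    assert output.numel() == 14
    assert padded_sizes == [torch.Size((6, 4)), torch.Size((4, 1))]
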
class _QuantizedOpFuture(Future[list[torch.Tensor]]):
    def __init__(
        self,
        sync_stream: cuda.Stream,
        keep_alive_tensors: list[torch.Tensor],
        return_tensors: list[torch.Tensor],
    ) -> None:
        super().__init__()
        self._sync_stream = sync_stream
        self._keep_alive_tensors = keep_alive_tensors
        self._return_tensors = return_tensors

    def wait(self) -> list[torch.Tensor]:
        # Wait for the synchronization to complete.
        cuda.current_stream().wait_stream(self._sync_stream)
        # Clean up intermediate buffers.
        del self._keep_alive_tensors
        return self._return_tensors


def reduce_scatter_quantized(
    output: torch.Tensor,
    inputs: list[torch.Tensor],
    opts: ReduceScatterOptions | ReduceOp,
    process_group: "ProcessGroup",
    sync_stream: cuda.Stream | None = None,
) -> Work:
    """
    Performs a quantized reduce-scatter operation on a list of tensors.

    This function implements an optimized reduce-scatter that reduces communication
    overhead by quantizing tensors to FP8 format before sending them over the
    network. The algorithm works as follows:

    1. Quantize input tensors to FP8 format
    2. Distribute chunks of quantized tensors to all ranks using all-to-all
    3. Reduce chunks locally in higher precision after dequantization
    4. Dequantize the result back to the original precision for the current rank

    This implementation only supports the AVG and SUM reduce operations.

    Args:
        output: Pre-allocated tensor to store the output of the reduce-scatter operation
        inputs: List of tensors to be reduced and scattered. All tensors must be on
            the same CUDA device and have the same dtype.
        opts: Options for the reduce-scatter operation. Can be either a
            ReduceScatterOptions object or a ReduceOp enum.
        process_group: The process group to perform the reduce-scatter on.
        sync_stream: Optional CUDA stream to use for synchronization. If None,
            a new stream will be created.

    Returns:
        A Work handle that can be used to wait for the operation to complete and
        clean up intermediate buffers.

    Raises:
        NotImplementedError: If the reduce operation is not ReduceOp.AVG or ReduceOp.SUM.
    """

    if isinstance(opts, ReduceOp):
        reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
        reducescatter_opts.reduceOp = opts
    else:
        reducescatter_opts = opts

    # Check if the reduceOp is AVG or SUM
    if reducescatter_opts.reduceOp not in {
        ReduceOp(ReduceOp.AVG),
        ReduceOp(ReduceOp.SUM),
    }:
        raise NotImplementedError(
            f"ReduceOp {reducescatter_opts.reduceOp} is not supported "
            "for quantized reduce-scatter, only AVG and SUM are supported"
        )

    rank: int = process_group.rank()
    world_size: int = process_group.size()

    reduce_output_sizes = [
        torch.Size((s[0] // world_size, *s[1:]))
        for s in get_padded_sizes(inputs, world_size)
    ]
    reduce_output_numels = [s.numel() for s in reduce_output_sizes]
    reduce_outputs: list[torch.Tensor] = [
        o.view(s)
        for o, s in zip(
            output.split(reduce_output_numels),
            reduce_output_sizes,
        )
    ]

    if sync_stream is None:
        sync_stream = cuda.Stream()

    assert sync_stream is not None
    # Ensure that all operations are completed on the current stream
    # before proceeding with the reduce-scatter
    sync_stream.wait_stream(cuda.current_stream())
    with cuda.stream(sync_stream):
        # Quantize tensors and compute their scales, all inlined in the
        # output tensor.
        quantized_inputs = fused_quantize_into_fp8(inputs, world_size)

        # Allocate output tensor where the all-to-all results will be stored
        quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)
        # Collect chunks and their scales from other ranks
        work = process_group.alltoall_base(
            quantized_inputs_out.view(world_size, -1),
            quantized_inputs.view(world_size, -1),
            [],
            [],
            _to_alltoall_options(reducescatter_opts),
        )
        work.wait()

        fut = work.get_future()

        def callback(fut: Future[list[torch.Tensor]]) -> None:
            nonlocal \
                inputs, \
                quantized_inputs_out, \
                world_size, \
                sync_stream, \
                rank, \
                reduce_outputs, \
                reducescatter_opts

            with torch.cuda.stream(sync_stream):
                # Setup stream dependency
                fut.wait()
                # Reduce chunks locally in higher precision after dequantization.
                # The output is again quantized.
                fused_reduce_fp8(
                    inputs,
                    quantized_inputs_out,
                    world_size,
                    rank,
                    reducescatter_opts.reduceOp,
                )

                # Get view into the output tensor that corresponds to the
                # current rank
                quantized_reduce_scatter = (
                    quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
                )
                # Dequantize the result back to the original precision for
                # the current rank
                fused_dequantize_from_fp8(
                    reduce_outputs,
                    quantized_reduce_scatter,
                    1,
                )

        fut.add_done_callback(callback)

        return work

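# Illustrative sketch (not part of the package file): how reduce_scatter_quantized
# is typically paired with allocate_reduce_scatter_output. The process group
# `pg`, the CUDA device, and the gradient shapes below are assumptions for
# illustration; this path requires a NCCL-style group and a GPU.
def _example_reduce_scatter_quantized(pg: "ProcessGroup") -> torch.Tensor:
    grads = [
        torch.randn(1024, 1024, device="cuda"),
        torch.randn(4096, device="cuda"),
    ]
    # Flat output buffer sized so that this rank receives its padded share.
    output, _padded_sizes = allocate_reduce_scatter_output(grads, pg.size())
    work = reduce_scatter_quantized(output, grads, ReduceOp.AVG, pg)
    # Per the docstring above, the returned handle is used to wait for
    # completion; `output` then holds this rank's averaged shard.
    work.wait()
    return output
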
def allreduce_quantized(
    tensors: list[torch.Tensor],
    opts: AllreduceOptions | ReduceOp,
    process_group: "ProcessGroup",
    sync_stream: cuda.Stream | None = None,
) -> Work:
    """
    Performs a quantized all-reduce operation on a list of tensors.

    This function implements an optimized all-reduce that reduces communication
    overhead by quantizing tensors to FP8 format before sending them over the
    network. The algorithm works as follows:

    1. Quantize input tensors to FP8 format
    2. Distribute chunks of quantized tensors to all ranks using all-to-all
    3. Reduce chunks locally in higher precision after dequantization
    4. Collect reduced chunks from all ranks using all-gather
    5. Dequantize the result back to the original precision

    This implementation only supports the AVG and SUM reduce operations.

    Args:
        tensors: List of tensors to be reduced. All tensors must be on the same
            CUDA device and have the same dtype.
        opts: Options for the all-reduce operation. Can be either an
            AllreduceOptions object or a ReduceOp enum. If a ReduceOp is
            provided, it must be ReduceOp.AVG or ReduceOp.SUM.
        process_group: The process group to perform the all-reduce on.
        sync_stream: Optional CUDA stream to use for synchronization. If None,
            a new stream will be created.

    Returns:
        A Work handle that can be used to wait for the operation to complete and
        clean up intermediate buffers.

    Raises:
        NotImplementedError: If the reduce operation is not ReduceOp.AVG or
            ReduceOp.SUM.
    """
    if isinstance(opts, ReduceOp):
        allreduce_opts = AllreduceOptions()
        allreduce_opts.reduceOp = opts
    else:
        allreduce_opts = opts

    # Check if the reduceOp is AVG or SUM
    if allreduce_opts.reduceOp not in {
        ReduceOp(ReduceOp.AVG),
        ReduceOp(ReduceOp.SUM),
    }:
        raise NotImplementedError(
            f"ReduceOp {allreduce_opts.reduceOp} is not supported "
            "for quantized allreduce, only AVG and SUM are supported"
        )

    rank = process_group.rank()
    world_size: int = process_group.size()

    if sync_stream is None:
        sync_stream = cuda.Stream()

    assert sync_stream is not None
    # Ensure that all operations are completed on the current stream
    # before proceeding with the all-reduce
    sync_stream.wait_stream(cuda.current_stream())
    with cuda.stream(sync_stream):
        # Quantize tensors and compute their scales, all inlined in the
        # output tensor.
        quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)

        # Allocate output tensor where the all-to-all results will be stored
        quantized_tensors_out = torch.zeros_like(quantized_tensors)
        # Collect chunks and their scales from other ranks
        process_group.alltoall_base(
            quantized_tensors_out.view(world_size, -1),
            quantized_tensors.view(world_size, -1),
            [],
            [],
            _to_alltoall_options(allreduce_opts),
        ).wait()

        # Reduce chunks locally in higher precision after dequantization.
        # The output is again quantized.
        fused_reduce_fp8(
            tensors,
            quantized_tensors_out,
            world_size,
            rank,
            allreduce_opts.reduceOp,
        )

        # Collect reduced chunks from other ranks.
        work = process_group.allgather_into_tensor_coalesced(
            [quantized_tensors.view(world_size, -1)],
            [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
            _to_allgather_options(allreduce_opts),
        )

        # NOTE: This is not supposed to be used with gloo, only with NCCL.
        # So we setup the stream dependency here by calling work.wait(),
        # which doesn't block the CPU.
        #
        # The future callback below will run after the work has been
        # completed.

        work.wait()
        fut = work.get_future()

        def callback(fut: Future[list[torch.Tensor]]) -> None:
            # Dequantize and copy to output buffer.
            nonlocal tensors, quantized_tensors, world_size, sync_stream

            with torch.cuda.stream(sync_stream):
                # Setup stream dependency
                fut.wait()
                # Dequantize the result back to the original precision
                fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)

        fut.add_done_callback(callback)
        return work
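
To close, a minimal usage sketch for allreduce_quantized, assuming a CUDA device and an already-initialized torchft ProcessGroup (both hypothetical here; the real groups live in torchft/process_group.py). Per the docstring and the callback above, the reduction is in place: once the returned handle completes, the input tensors hold the dequantized, averaged values.

def example_allreduce_quantized(pg: "ProcessGroup") -> None:
    grads = [
        torch.randn(1024, 1024, device="cuda"),
        torch.randn(4096, device="cuda"),
    ]
    # Quantized all-reduce with averaging; `grads` is updated in place.
    work = allreduce_quantized(grads, ReduceOp.AVG, pg)
    work.wait()
    # `grads` now contains the dequantized, averaged tensors.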