torchft-nightly 2026.1.3 cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
@@ -0,0 +1,161 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import threading
+ import traceback
+ from concurrent.futures import as_completed, ThreadPoolExecutor
+ from datetime import timedelta
+ from typing import Callable
+ from unittest import TestCase
+
+ import torch
+ import torch.distributed as dist
+ from torch.distributed.tensor import DeviceMesh, distribute_tensor, DTensor
+
+ from torchft.checkpointing.transport import CheckpointTransport
+
+ TIMEOUT_REGEX = r".*(Timed out|timed out|timeout|time out).*"
+
+
+ def assertStateDictEqual(
+     self: TestCase, a: dict[str, object], b: dict[str, object]
+ ) -> None:
+     for k, v1 in a.items():
+         v2 = b[k]
+         if isinstance(v1, DTensor) and isinstance(v2, DTensor):
+             torch.testing.assert_close(v1._local_tensor.cpu(), v2._local_tensor.cpu())
+             self.assertEqual(v1._spec, v2._spec)
+         elif isinstance(v1, torch.Tensor) and isinstance(v2, torch.Tensor):
+             torch.testing.assert_close(v1.cpu(), v2.cpu())
+         else:
+             self.assertEqual(v1, v2)
+
+
+ def make_state_dict(device: torch.device) -> dict[str, object]:
+     device_mesh = DeviceMesh("cpu", 1)
+     tensor = torch.tensor([5, 6, 7])
+     dtensor: DTensor = distribute_tensor(tensor, device_mesh, [])
+
+     return {
+         "rank": torch.tensor([1, 2, 3], device=device),
+         # "strided": torch.tensor([10], device=device)[1::2],
+         "str": "str",
+         "int": 1234,
+         "dtensor": dtensor,
+     }
+
+
+ def run_multi_recovery_test(
+     self: TestCase,
+     init_transport: Callable[[int, int], CheckpointTransport[dict[str, object]]],
+     device: torch.device,
+ ) -> None:
+     """
+     This runs multi node recovery tests for a given transport function.
+
+     This tests send/recv in a 3 node setup, with all and with some workers
+     recovering, and also tests timeout behavior.
+     """
+     WORLD_SIZE: int = 3
+
+     # barrier is used to simulate quorum/allreduce barriers
+     barrier: threading.Barrier = threading.Barrier(WORLD_SIZE, timeout=10)
+     metadata: str = ""
+
+     dist.init_process_group(
+         backend="gloo", rank=0, world_size=1, store=dist.HashStore()
+     )
+
+     def run(rank: int) -> CheckpointTransport[dict[str, object]]:
+         transport = init_transport(rank, WORLD_SIZE)
+
+         if rank == 0:
+             nonlocal metadata
+             metadata = transport.metadata()
+
+         barrier.wait()
+
+         state_dict: dict[str, object] = make_state_dict(device)
+
+         # 3 node recovery
+         if rank == 0:
+             transport.send_checkpoint(
+                 dst_ranks=[1, 2],
+                 step=1,
+                 state_dict=state_dict,
+                 timeout=timedelta(seconds=10),
+             )
+         else:
+             got = transport.recv_checkpoint(
+                 src_rank=0, metadata=metadata, step=1, timeout=timedelta(seconds=10)
+             )
+             assertStateDictEqual(self, got, state_dict)
+
+         barrier.wait()
+         transport.disallow_checkpoint()
+
+         # 2 node recovery
+         if rank == 0:
+             transport.send_checkpoint(
+                 dst_ranks=[2],
+                 step=2,
+                 state_dict=state_dict,
+                 timeout=timedelta(seconds=10),
+             )
+         elif rank == 2:
+             got = transport.recv_checkpoint(
+                 src_rank=0, metadata=metadata, step=2, timeout=timedelta(seconds=10)
+             )
+             assertStateDictEqual(self, got, state_dict)
+
+         barrier.wait()
+         transport.disallow_checkpoint()
+
+         # timeout test
+         if rank == 2:
+             with self.assertRaisesRegex(Exception, TIMEOUT_REGEX):
+                 transport.recv_checkpoint(
+                     src_rank=0,
+                     metadata=metadata,
+                     step=3,
+                     timeout=timedelta(milliseconds=10),
+                 )
+
+             # Make sure send completes quickly.
+             # If the transport is async (such as with HTTP) this may just return
+             # immediately.
+             try:
+                 transport.send_checkpoint(
+                     dst_ranks=[0],
+                     step=4,
+                     state_dict=state_dict,
+                     timeout=timedelta(seconds=10),
+                 )
+             except Exception:
+                 with self.assertRaisesRegex(Exception, TIMEOUT_REGEX):
+                     raise
+
+         return transport
+
+     with ThreadPoolExecutor(max_workers=WORLD_SIZE) as executor:
+         results = []
+         for i in range(WORLD_SIZE):
+             results.append(executor.submit(run, i))
+
+         transports = []
+
+         try:
+             for fut in as_completed(results, timeout=10.0):
+                 transports.append(fut.result())
+         except Exception as e:
+             print(e)
+             traceback.print_exc()
+             raise
+
+         for transport in transports:
+             transport.shutdown()
+
+     dist.destroy_process_group()
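
For context, run_multi_recovery_test only needs a factory that builds one CheckpointTransport per rank. Below is a minimal sketch of how a concrete test case might drive it; it assumes the HTTPTransport class from torchft/checkpointing/http_transport.py accepts a timeout and chunk count, so the constructor arguments here are illustrative and the real signature may differ.

from datetime import timedelta
from unittest import TestCase

import torch

from torchft.checkpointing.http_transport import HTTPTransport
from torchft.checkpointing.transport import CheckpointTransport
from torchft.checkpointing.transport_test import run_multi_recovery_test


class HTTPTransportMultiRecoveryTest(TestCase):
    def test_multi_recovery_cpu(self) -> None:
        # Factory contract assumed by the harness: (rank, world_size) -> transport.
        def init(rank: int, world_size: int) -> CheckpointTransport[dict[str, object]]:
            # Constructor arguments are illustrative; see http_transport.py.
            return HTTPTransport(timeout=timedelta(seconds=10), num_chunks=0)

        run_multi_recovery_test(self, init, device=torch.device("cpu"))
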
torchft/collectives.py ADDED
@@ -0,0 +1,415 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ # All rights reserved.
+ #
+ # This source code is licensed under the BSD-style license found in the
+ # LICENSE file in the root directory of this source tree.
+
+ import math
+ from typing import TYPE_CHECKING
+
+ import torch
+
+ # pyre-ignore[21]: Could not find a module corresponding to import `triton`
+ import triton
+ from torch import cuda
+ from torch.distributed import ReduceOp
+ from torch.distributed.distributed_c10d import (
+     AllgatherOptions,
+     AllreduceOptions,
+     AllToAllOptions,
+     ReduceScatterOptions,
+     Work,
+ )
+ from torch.futures import Future
+
+ if TYPE_CHECKING:
+     from torchft.process_group import ProcessGroup
+
+ from torchft.quantization import (
+     fused_dequantize_from_fp8,
+     fused_quantize_into_fp8,
+     fused_reduce_fp8,
+ )
+
+
+ def _to_alltoall_options(
+     opts: AllreduceOptions | ReduceScatterOptions,
+ ) -> AllToAllOptions:
+     alltoall_opts = AllToAllOptions()
+     alltoall_opts.timeout = opts.timeout
+     return alltoall_opts
+
+
+ def _to_allgather_options(
+     opts: AllreduceOptions | ReduceScatterOptions,
+ ) -> AllgatherOptions:
+     allgather_opts = AllgatherOptions()
+     allgather_opts.timeout = opts.timeout
+     return allgather_opts
+
+
+ def get_padded_sizes(
+     tensors: list[torch.Tensor],
+     world_size: int,
+ ) -> list[torch.Size]:
+     """
+     Calculate padded sizes for tensors to ensure they can be evenly
+     divided across ranks.
+
+     This function computes padded tensor sizes by rounding up the
+     first dimension of each tensor to be a multiple of the world_size.
+     This ensures that when tensors are split across ranks
+     in distributed operations, each process receives an equal
+     number of elements.
+
+     Args:
+         tensors: List of tensors whose sizes need to be padded
+         world_size: Number of ranks in the distributed setup
+
+     Returns:
+         List of torch.Size objects with the first dimension padded
+         to be a multiple of world_size
+
+     Note:
+         For 1D tensors, they are treated as 2D tensors with a
+         second dimension of 1
+     """
+     padded_sizes = []
+     for tensor in tensors:
+         size = tensor.size()
+         if len(size) == 1:
+             size = (size[0], 1)
+         padded_m = math.ceil(size[0] / world_size) * world_size
+         padded_sizes.append(torch.Size((padded_m, *size[1:])))
+     return padded_sizes
+
+
+ def allocate_reduce_scatter_output(
+     tensors: list[torch.Tensor],
+     world_size: int,
+ ) -> tuple[torch.Tensor, list[torch.Size]]:
+     """
+     Allocate tensor for the output of a reduce-scatter operation.
+
+     This function creates a single contiguous tensor to hold the results of a
+     reduce-scatter operation across multiple ranks. It ensures that the tensor
+     is properly sized and shaped to accommodate the results, where each rank
+     will receive a portion of the reduced data.
+
+     Args:
+         tensors: List of input tensors for the reduce-scatter operation.
+             All tensors must be on the same device and have the same
+             data type.
+         world_size: Number of ranks in the distributed setup
+
+     Returns:
+         A tuple containing:
+         - A single contiguous tensor allocated for the reduce-scatter output
+         - A list of padded sizes for the input tensors that were split across
+           ranks
+
+     Raises:
+         AssertionError: If the input tensors are not all on the same device or
+             do not all have the same data type
+     """
+     device = tensors[0].device
+     dtype = tensors[0].dtype
+     for i in range(1, len(tensors)):
+         assert (
+             tensors[i].device == tensors[i - 1].device
+         ), "All inputs must be on the same device"
+         assert (
+             tensors[i].dtype == tensors[i - 1].dtype
+         ), "All inputs must have the same dtype"
+
+     padded_sizes = get_padded_sizes(tensors, world_size)
+
+     chunks = []
+     numels = [size.numel() // world_size for size in padded_sizes]
+     tensor = torch.empty(
+         (sum(numels),),
+         device=device,
+         dtype=dtype,
+     )
+     for split, padded_size in zip(torch.split(tensor, numels), padded_sizes):
+         chunks.append(split.view(padded_size[0] // world_size, *padded_size[1:]))
+     return tensor, padded_sizes
+
+
+ class _QuantizedOpFuture(Future[list[torch.Tensor]]):
+     def __init__(
+         self,
+         sync_stream: cuda.Stream,
+         keep_alive_tensors: list[torch.Tensor],
+         return_tensors: list[torch.Tensor],
+     ) -> None:
+         super().__init__()
+         self._sync_stream = sync_stream
+         self._keep_alive_tensors = keep_alive_tensors
+         self._return_tensors = return_tensors
+
+     def wait(self) -> list[torch.Tensor]:
+         # Wait for the synchronization to complete.
+         cuda.current_stream().wait_stream(self._sync_stream)
+         # Clean up intermediate buffers.
+         del self._keep_alive_tensors
+         return self._return_tensors
+
+
+ def reduce_scatter_quantized(
+     output: torch.Tensor,
+     inputs: list[torch.Tensor],
+     opts: ReduceScatterOptions | ReduceOp,
+     process_group: "ProcessGroup",
+     sync_stream: cuda.Stream | None = None,
+ ) -> Work:
+     """
+     Performs a quantized reduce-scatter operation on a list of tensors.
+
+     This function implements an optimized reduce-scatter that reduces communication
+     overhead by quantizing tensors to FP8 format before sending them over the
+     network. The algorithm works as follows:
+
+     1. Quantize input tensors to FP8 format
+     2. Distribute chunks of quantized tensors to all ranks using all-to-all
+     3. Reduce chunks locally in higher precision after dequantization
+     4. Dequantize the result back to the original precision for the current rank
+
+     This implementation only supports the AVG and SUM reduce operations.
+
+     Args:
+         output: Pre-allocated tensor to store the output of the reduce-scatter operation
+         inputs: List of tensors to be reduced and scattered. All tensors must be on
+             the same CUDA device and have the same dtype.
+         opts: Options for the reduce-scatter operation. Can be either a
+             ReduceScatterOptions object or a ReduceOp enum.
+         process_group: The process group to perform the reduce-scatter on.
+         sync_stream: Optional CUDA stream to use for synchronization. If None,
+             a new stream will be created.
+
+     Returns:
+         A Work object that can be used to wait for the operation to complete and
+         clean up intermediate buffers.
+
+     Raises:
+         NotImplementedError: If the reduce operation is not ReduceOp.AVG or ReduceOp.SUM.
+     """
+
+     if isinstance(opts, ReduceOp):
+         reducescatter_opts: ReduceScatterOptions = ReduceScatterOptions()
+         reducescatter_opts.reduceOp = opts
+     else:
+         reducescatter_opts: ReduceScatterOptions = opts
+
+     # Check if the reduceOp is AVG or SUM
+     if reducescatter_opts.reduceOp not in {
+         ReduceOp(ReduceOp.AVG),
+         ReduceOp(ReduceOp.SUM),
+     }:
+         raise NotImplementedError(
+             f"ReduceOp {reducescatter_opts.reduceOp} is not supported "
+             f"for quantized reduce-scatter, only AVG and SUM are supported"
+         )
+
+     rank: int = process_group.rank()
+     world_size: int = process_group.size()
+
+     reduce_output_sizes = [
+         torch.Size((s[0] // world_size, *s[1:]))
+         for s in get_padded_sizes(inputs, world_size)
+     ]
+     reduce_output_numels = [s.numel() for s in reduce_output_sizes]
+     reduce_outputs: list[torch.Tensor] = [
+         o.view(s)
+         for o, s in zip(
+             output.split(reduce_output_numels),
+             reduce_output_sizes,
+         )
+     ]
+
+     if sync_stream is None:
+         sync_stream = cuda.Stream()
+
+     assert sync_stream is not None
+     # Ensure that all operations are completed on the current stream
+     # before proceeding with the reduce-scatter
+     sync_stream.wait_stream(cuda.current_stream())
+     with cuda.stream(sync_stream):
+         # Quantize tensors and compute their scales, all inlined in the
+         # output tensor.
+         quantized_inputs = fused_quantize_into_fp8(inputs, world_size)
+
+         # Allocate output tensor where all-to-all results will be stored
+         quantized_inputs_out: torch.Tensor = torch.zeros_like(quantized_inputs)
+         # Collect chunks and their scales from other ranks
+         work = process_group.alltoall_base(
+             quantized_inputs_out.view(world_size, -1),
+             quantized_inputs.view(world_size, -1),
+             [],
+             [],
+             _to_alltoall_options(reducescatter_opts),
+         )
+         work.wait()
+
+         fut = work.get_future()
+
+         def callback(fut: Future[list[torch.Tensor]]) -> None:
+             nonlocal \
+                 inputs, \
+                 quantized_inputs_out, \
+                 world_size, \
+                 sync_stream, \
+                 rank, \
+                 reduce_outputs, \
+                 reducescatter_opts
+
+             with torch.cuda.stream(sync_stream):
+                 # Setup stream dependency
+                 fut.wait()
+                 # Reduce chunks locally in higher precision after dequantization.
+                 # The output is again quantized.
+                 fused_reduce_fp8(
+                     inputs,
+                     quantized_inputs_out,
+                     world_size,
+                     rank,
+                     reducescatter_opts.reduceOp,
+                 )
+
+                 # Get view into the output tensor that corresponds to the
+                 # current rank
+                 quantized_reduce_scatter = (
+                     quantized_inputs_out.view(world_size, -1).split(1)[rank].squeeze(0)
+                 )
+                 # Dequantize the result back to the original precision for
+                 # the current rank
+                 fused_dequantize_from_fp8(
+                     reduce_outputs,
+                     quantized_reduce_scatter,
+                     1,
+                 )
+
+         fut.add_done_callback(callback)
+
+         return work
+
+
+ def allreduce_quantized(
+     tensors: list[torch.Tensor],
+     opts: AllreduceOptions | ReduceOp,
+     process_group: "ProcessGroup",
+     sync_stream: cuda.Stream | None = None,
+ ) -> Work:
+     """
+     Performs a quantized all-reduce operation on a list of tensors.
+
+     This function implements an optimized all-reduce that reduces communication
+     overhead by quantizing tensors to FP8 format before sending them over the
+     network. The algorithm works as follows:
+
+     1. Quantize input tensors to FP8 format
+     2. Distribute chunks of quantized tensors to all ranks using all-to-all
+     3. Reduce chunks locally in higher precision after dequantization
+     4. Collect reduced chunks from all ranks using all-gather
+     5. Dequantize the result back to the original precision
+
+     This implementation only supports the AVG and SUM reduce operations.
+
+     Args:
+         tensors: List of tensors to be reduced. All tensors must be on the same
+             CUDA device and have the same dtype.
+         opts: Options for the all-reduce operation. Can be either an
+             AllreduceOptions object or a ReduceOp enum. If a ReduceOp is
+             provided, it must be ReduceOp.AVG or ReduceOp.SUM.
+         process_group: The process group to perform the all-reduce on.
+         sync_stream: Optional CUDA stream to use for synchronization. If None,
+             a new stream will be created.
+
+     Returns:
+         A Work object that can be used to wait for the operation to complete and
+         clean up intermediate buffers.
+
+     Raises:
+         NotImplementedError: If the reduce operation is not ReduceOp.AVG or ReduceOp.SUM.
+     """
+     if isinstance(opts, ReduceOp):
+         allreduce_opts = AllreduceOptions()
+         allreduce_opts.reduceOp = opts
+     else:
+         allreduce_opts = opts
+
+     # Check if the reduceOp is AVG or SUM
+     if allreduce_opts.reduceOp not in {
+         ReduceOp(ReduceOp.AVG),
+         ReduceOp(ReduceOp.SUM),
+     }:
+         raise NotImplementedError(
+             f"ReduceOp {allreduce_opts.reduceOp} is not supported "
+             f"for quantized allreduce, only AVG and SUM are supported"
+         )
+
+     rank = process_group.rank()
+     world_size: int = process_group.size()
+
+     if sync_stream is None:
+         sync_stream = cuda.Stream()
+
+     assert sync_stream is not None
+     # Ensure that all operations are completed on the current stream
+     # before proceeding with all-reduce
+     sync_stream.wait_stream(cuda.current_stream())
+     with cuda.stream(sync_stream):
+         # Quantize tensors and compute their scales, all inlined in the
+         # output tensor.
+         quantized_tensors: torch.Tensor = fused_quantize_into_fp8(tensors, world_size)
+
+         # Allocate output tensor where all-reduce results will be stored
+         quantized_tensors_out = torch.zeros_like(quantized_tensors)
+         # Collect chunks and their scales from other ranks
+         process_group.alltoall_base(
+             quantized_tensors_out.view(world_size, -1),
+             quantized_tensors.view(world_size, -1),
+             [],
+             [],
+             _to_alltoall_options(allreduce_opts),
+         ).wait()
+
+         # Reduce chunks locally in higher precision after dequantization.
+         # The output is again quantized.
+         fused_reduce_fp8(
+             tensors,
+             quantized_tensors_out,
+             world_size,
+             rank,
+             allreduce_opts.reduceOp,
+         )
+
+         # Collect reduced chunks from other ranks.
+         work = process_group.allgather_into_tensor_coalesced(
+             [quantized_tensors.view(world_size, -1)],
+             [torch.split(quantized_tensors_out.view(world_size, -1), 1)[rank]],
+             _to_allgather_options(allreduce_opts),
+         )
+
+         # NOTE: This is not supposed to be used with gloo, only with NCCL.
+         # So we setup the stream dependency here by calling work.wait(),
+         # which doesn't block the CPU.
+         #
+         # The future callback below will run after the work has been
+         # completed.
+
+         work.wait()
+         fut = work.get_future()
+
+         def callback(fut: Future[list[torch.Tensor]]) -> None:
+             # Dequantize and copy to output buffer.
+             nonlocal tensors, quantized_tensors, world_size, sync_stream
+
+             with torch.cuda.stream(sync_stream):
+                 # Setup stream dependency
+                 fut.wait()
+                 # Dequantize the result back to the original precision
+                 fused_dequantize_from_fp8(tensors, quantized_tensors, world_size)
+
+         fut.add_done_callback(callback)
+         return work
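
For context, the padding rule used by get_padded_sizes and the entry point allreduce_quantized can be exercised as in the minimal sketch below. It assumes an already-initialized NCCL-backed torchft ProcessGroup is available as pg (not constructed here), and that triton and a CUDA device are present, since this module imports triton at load time. The average_gradients helper is illustrative, not part of the package.

import torch
from torch.distributed import ReduceOp

from torchft.collectives import allreduce_quantized, get_padded_sizes
from torchft.process_group import ProcessGroup

# Padding rule: the first dimension is rounded up to a multiple of world_size,
# and 1D tensors are treated as (N, 1). With world_size=3:
#   (5, 4) -> (6, 4)   and   (7,) -> (9, 1)
sizes = get_padded_sizes([torch.empty(5, 4), torch.empty(7)], 3)
assert sizes == [torch.Size((6, 4)), torch.Size((9, 1))]


def average_gradients(pg: ProcessGroup, grads: list[torch.Tensor]) -> None:
    # All gradients must live on the same CUDA device and share a dtype.
    # Passing ReduceOp.AVG (or SUM) is converted to AllreduceOptions internally.
    work = allreduce_quantized(grads, ReduceOp.AVG, pg)
    # wait() establishes the stream dependency; the registered callback then
    # dequantizes the averaged result back into `grads` in place.
    work.wait()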