torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl
- torchft/__init__.py +34 -0
- torchft/_test/diloco_trainer.py +287 -0
- torchft/_test/managed_work_test.py +320 -0
- torchft/_test_utils.py +111 -0
- torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
- torchft/_torchft.pyi +116 -0
- torchft/checkpointing/__init__.py +20 -0
- torchft/checkpointing/_rwlock.py +136 -0
- torchft/checkpointing/_serialization.py +39 -0
- torchft/checkpointing/http_transport.py +299 -0
- torchft/checkpointing/http_transport_bench.py +61 -0
- torchft/checkpointing/http_transport_test.py +146 -0
- torchft/checkpointing/pg_transport.py +306 -0
- torchft/checkpointing/pg_transport_bench.py +99 -0
- torchft/checkpointing/pg_transport_test.py +101 -0
- torchft/checkpointing/rwlock_test.py +58 -0
- torchft/checkpointing/transport.py +68 -0
- torchft/checkpointing/transport_test.py +161 -0
- torchft/collectives.py +415 -0
- torchft/collectives_test.py +212 -0
- torchft/coordination.py +39 -0
- torchft/coordination_test.py +29 -0
- torchft/data.py +77 -0
- torchft/data_test.py +39 -0
- torchft/ddp.py +105 -0
- torchft/ddp_test.py +68 -0
- torchft/diloco_regression_test.py +644 -0
- torchft/examples/slurm/README.md +34 -0
- torchft/examples/slurm/punisher.py +95 -0
- torchft/examples/slurm/runner.py +221 -0
- torchft/fsdp_test.py +102 -0
- torchft/futures.py +353 -0
- torchft/futures_test.py +140 -0
- torchft/http.py +13 -0
- torchft/lighthouse_test.py +163 -0
- torchft/local_sgd.py +796 -0
- torchft/local_sgd_integ_test.py +600 -0
- torchft/local_sgd_test.py +324 -0
- torchft/manager.py +1358 -0
- torchft/manager_integ_test.py +653 -0
- torchft/manager_test.py +911 -0
- torchft/multiprocessing.py +38 -0
- torchft/multiprocessing_dummy_context.py +135 -0
- torchft/multiprocessing_test.py +58 -0
- torchft/optim.py +63 -0
- torchft/optim_test.py +50 -0
- torchft/otel.py +134 -0
- torchft/parameter_server.py +195 -0
- torchft/parameter_server_test.py +47 -0
- torchft/process_group.py +2118 -0
- torchft/process_group_test.py +1028 -0
- torchft/quantization.py +686 -0
- torchft/quantization_test.py +131 -0
- torchft/torchx.py +89 -0
- torchft/utils.py +67 -0
- torchft/work.py +26 -0
- torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
- torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
- torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
- torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
- torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd.py
ADDED
@@ -0,0 +1,796 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
LocalSGD
=========
This module implements a fault tolerant version of LocalSGD and related methods.
"""

import logging
import math
import os
from contextlib import nullcontext
from types import TracebackType
from typing import Any, Dict, List, Optional, Tuple, Type

import torch
from torch import nn, optim
from torch.distributed.distributed_c10d import Work
from torch.distributed.tensor import DTensor
from torch.utils.hooks import RemovableHandle

from torchft.manager import Manager

logger: logging.Logger = logging.getLogger(__name__)

USE_BUCKETIZATION_ENV: str = "TORCHFT_USE_BUCKETIZATION"


def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
    """
    Returns a cloned version of the input tensor. If the input tensor is a DTensor,
    it extracts and clones its local representation.
    """
    new_tensor = None
    if isinstance(t, DTensor):
        new_tensor = t.to_local().clone()
    else:
        new_tensor = t.clone()
    new_tensor.grad = None
    return new_tensor


class LocalSGD:
    """
    LocalSGD is a context manager that
    implements the algorithm described in https://arxiv.org/pdf/1805.09767

    This will synchronize the model parameters periodically in a fault tolerant
    way using a torchft Manager. The allreduce on the parameters will happen
    every sync_every steps after the optimizer.step call.

    The torchft quorum is computed at the beginning of ``sync_every`` steps. If
    any error occurs, or a worker fails between syncs, ``sync_every`` steps will be
    discarded and a new quorum will be computed on the next step.

    If running in async mode, on a joining worker the first ``sync_every`` steps
    will be discarded as the model will be recovering during that period. When
    using sync mode, the checkpoint will be restored prior to the first step.
    """

    def __init__(
        self,
        manager: Manager,
        model: nn.Module,
        optimizer: optim.Optimizer,
        sync_every: int,
    ) -> None:
        """
        Args:
            manager: The manager to use.
            model: The model to wrap.
            optimizer: The optimizer used by the model.
            sync_every: How often to sync the model weights.
        """
        super().__init__()
        self._manager = manager
        self._model = model
        self._local_optimizer = optimizer
        self._local_step = 0
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"

        self._hooks: List[RemovableHandle] = []

    def __enter__(self) -> "LocalSGD":
        self._hooks.append(
            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
        )
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _step_pre_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
        self._manager.disallow_state_dict_read()

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        self._manager.allow_state_dict_read()

        self._local_step += 1
        if self._local_step >= self._sync_every:
            self.sync()

    def sync(self) -> None:
        """
        Synchronizes and averages the model weights across the manager.
        """
        self._manager.start_quorum()
        self._perform_sync()
        self._local_step = 0

    def _perform_sync(self) -> None:
        """
        Performs the synchronization of the model weights across the manager.
        """
        averaged_parameters = self._average()
        if self._manager.should_commit():
            # Update the model parameters with the averaged values
            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
                if isinstance(param, DTensor):
                    # we averaged the local version of the tensor so need to copy it back as a DTensor
                    param.data.copy_(
                        DTensor.from_local(
                            avg_param,
                            param.device_mesh,
                            param.placements,
                            shape=param.shape,
                            stride=param.stride(),
                        )
                    )
                else:
                    param.data.copy_(avg_param)

    def _average(self) -> list[torch.Tensor]:
        """
        Averages the model parameters across the manager and returns the averaged parameters.
        """
        works = []
        averaged_parameters = []
        for p in self._model.parameters():
            # Create a new tensor to store the averaged parameter
            avg_param = extract_local_tensor(p)
            works.append(self._manager.allreduce(avg_param))
            averaged_parameters.append(avg_param)
        for work in works:
            work.wait()
        return averaged_parameters
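
For orientation, here is a minimal usage sketch of the LocalSGD context manager above. It is illustrative only and not part of the packaged file: the model, loss, and learning rate are assumptions, and construction of the torchft Manager (lighthouse address, replica group configuration) is elided.

# Illustrative sketch -- not part of torchft/local_sgd.py. Assumes a torchft
# Manager has already been constructed elsewhere.
import torch
from torch import nn, optim

from torchft.local_sgd import LocalSGD
from torchft.manager import Manager


def train(manager: Manager, num_steps: int) -> None:
    model = nn.Linear(10, 10)
    optimizer = optim.SGD(model.parameters(), lr=1e-2)

    # Average parameters across replicas every 100 optimizer steps.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for _ in range(num_steps):
            x = torch.randn(8, 10)
            loss = (model(x) ** 2).mean()
            loss.backward()
            optimizer.step()  # the post-step hook counts steps and triggers sync()
            optimizer.zero_grad()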

class _StreamingDiLoCoFragment:
    bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
    use_bucketization: bool = False

    def __init__(
        self,
        manager: Manager,
        model_fragment: nn.Module,
        fragment_id: int,
        fragment_sync_offset: int,
        inner_optimizer: optim.Optimizer,
        outer_optimizer: optim.Optimizer,
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
        fragment_sync_delay: int = 0,
        fragment_update_alpha: float = 0.0,
    ) -> None:
        if fragment_sync_offset > sync_every:
            raise ValueError("Fragment must be synced once before `sync_every` steps")

        self._fragment_id = fragment_id
        self._manager = manager
        self._model_fragment = model_fragment
        self._fragment_sync_offset = fragment_sync_offset
        self._local_optimizer = inner_optimizer
        self._sync_every = sync_every
        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
        self._backup_device = backup_device
        self._pin_memory = pin_memory
        self._fragment_sync_delay = fragment_sync_delay
        self._fragment_update_alpha = fragment_update_alpha

        self._outer_optimizer = outer_optimizer

        # Stores pending all reduce
        self._allreduce_work: list[Work] = []
        self._stream: Optional[torch.cuda.Stream] = (
            torch.cuda.Stream() if torch.cuda.is_available() else None
        )

        # Recorded on `_stream` to wait for allreduce to finish
        self._stop_event: Optional[torch.cuda.Event] = None

        if bucket_cap_mb is not None:
            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)

        if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True":
            self.use_bucketization = True
        else:
            self.use_bucketization = use_bucketization

        self.should_quantize = should_quantize

        self._grads: Dict[str, torch.Tensor] = {}

        # Used to save global parameters so that they can be restored in case
        # commit fails
        self.original_parameters: Dict[str, torch.Tensor] = {}

        # Used to mix the local and global parameters
        self._local_parameters: Dict[str, torch.Tensor] = {}

        for name, p in self._model_fragment.named_parameters():
            if isinstance(p, DTensor):
                p = extract_local_tensor(p.data)

            backup_device = self._backup_device or torch.device("cpu")
            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)
            if (
                self._pin_memory
                and t.device == torch.device("cpu")
                and torch.cuda.is_available()
            ):
                t = t.pin_memory()
            self.original_parameters[name] = t

    def register_state_dict_fn(self) -> None:
        """
        Register state dict functions for this fragment with the manager.
        This allows for saving and loading the original_parameters during checkpointing and recovery.

        Args:
            manager: The manager to register with
            fragment_id: Optional identifier for this fragment, used in the key
        """
        # Generate a unique key for this fragment based on the model fragment's name or provided ID
        fragment_key = f"StreamingDiLoCoFragment_{self._fragment_id}"

        # Define load function for this fragment
        def load_fn(state_dict: Dict[str, Dict[str, torch.Tensor]]) -> None:
            for name, param in state_dict["original_parameters"].items():
                if name in self.original_parameters:
                    self.original_parameters[name].copy_(param)

            self._outer_optimizer.load_state_dict(state_dict["outer_optimizer"])

        # Define save function for this fragment
        def save_fn() -> Dict[str, Dict[str, torch.Tensor]]:
            return {
                "outer_optimizer": self._outer_optimizer.state_dict(),
                "original_parameters": {
                    name: extract_local_tensor(param)
                    for name, param in self.original_parameters.items()
                },
            }

        # Register the functions with the manager
        self._manager.register_state_dict_fn(fragment_key, load_fn, save_fn)

    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
    def save_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model_fragment.named_parameters():
                param_to_local = extract_local_tensor(p.data)
                self.original_parameters[name].copy_(param_to_local, non_blocking=True)

    def _save_local_parameters(self) -> None:
        """
        Saves a copy of the model's parameters.
        """
        with torch.no_grad():
            for name, p in self._model_fragment.named_parameters():
                self._local_parameters[name] = extract_local_tensor(p.data)

    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
    def restore_parameters(self) -> None:
        with torch.no_grad():
            # TODO: consider running copy on a separate stream
            for name, p in self._model_fragment.named_parameters():
                if isinstance(p, DTensor):
                    # we averaged the local version of the tensor so need to copy it back as a DTensor
                    p.data.copy_(
                        DTensor.from_local(
                            self.original_parameters[name],
                            p.device_mesh,
                            p.placements,
                            shape=p.shape,
                            stride=p.stride(),
                        ),
                        non_blocking=False,
                    )
                else:
                    p.data.copy_(self.original_parameters[name], non_blocking=False)

    def _save_grads(self) -> None:
        """
        Saves pseudo-gradients of the parameters
        """
        with torch.no_grad():
            for name, p in self._model_fragment.named_parameters():
                if isinstance(p, DTensor):
                    local_param = p.to_local()
                else:
                    local_param = p
                pseudogradient = (
                    self.original_parameters[name].to(p.device) - local_param
                )
                self._grads[name] = pseudogradient

    def _set_grads(self) -> None:
        """
        Sets the gradients of the model fragment from the allreduce result
        """
        with torch.no_grad():
            for name, p in self._model_fragment.named_parameters():
                # avoid copying the gradient, it should be on the same device
                if isinstance(p, DTensor):
                    p.grad = DTensor.from_local(
                        self._grads[name],
                        p.device_mesh,
                        p.placements,
                        shape=p.shape,
                        stride=p.stride(),
                    )
                else:
                    p.grad = self._grads[name]

                # No longer needed
                del self._grads[name]

    def _clear_local_parameters(self) -> None:
        """
        Clears the saved copy of the model's parameters
        """
        self._local_parameters = {}

    def _merge_parameters(self) -> None:
        """
        Merges the local and global parameters.
        """
        for name, p in self._model_fragment.named_parameters():
            # we averaged the local version of the tensor so need to copy it back as a DTensor
            if isinstance(p, DTensor):
                p.data.lerp_(
                    DTensor.from_local(
                        self._local_parameters[name],
                        p.device_mesh,
                        p.placements,
                        shape=p.shape,
                        stride=p.stride(),
                    ),
                    self._fragment_update_alpha,
                )
            else:
                p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)

    @torch.profiler.record_function("torchft::local_sgd::wait")
    def wait(self) -> None:
        """
        Waits for the previously scheduled allreduce to finish
        """
        if len(self._allreduce_work) == 0:
            return

        if self._stream is not None:
            assert self._stop_event is not None
            self._stop_event.synchronize()
            self._stop_event = None

        self._allreduce_work = []

    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
    def prepare_sync(self) -> None:
        """
        Calculates the pseudogradients, averages them across the manager group and
        starts an allreduce on the pseudogradients but doesn't wait for it to finish.
        """
        self._save_grads()

        assert len(self._allreduce_work) == 0

        # Make sure tensors are available to `_stream`
        if self._stream is not None:
            self._stream.wait_stream(torch.cuda.current_stream())

        with (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        ):
            self._average_grads()

    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
    def perform_sync(self) -> bool:
        """
        Overrides the sync method to wait for the scheduled allreduce to finish and
        steps using the outer optimizer.
        """
        # Waiting for an allreduce before it has been sent is currently not supported.
        assert len(self._allreduce_work) > 0

        with (
            torch.cuda.stream(self._stream)
            if self._stream is not None
            else nullcontext()
        ):
            for work in self._allreduce_work:
                work.wait()

            if self._stream is not None:
                self._stop_event = torch.cuda.Event()
                self._stop_event.record()

        self.wait()

        # save the parameters so they can be used for merging
        self._save_local_parameters()
        # Restore the parameters back to the previous state
        self.restore_parameters()

        # For large values of `fragment_sync_delay`, this call can be
        # a problem.
        #
        # This can return success even if the allreduce failed. Because
        # the process group could have been reconfigured while the
        # allreduce was inflight. The inflight allreduce may or may
        # not have been aborted.
        #
        # We can track errors per allreduce to
        # let the commit fail here. But this has the downside of
        # reconfiguring the pg too many times resulting in
        # more aborts and more commit failures.
        should_commit = self._manager.should_commit()

        if should_commit:
            # Use the outer optimizer to update the model parameters
            self._set_grads()
            self._outer_optimizer.step()
            self.save_parameters()
            self._merge_parameters()
        self._outer_optimizer.zero_grad()

        # free up memory
        self._clear_local_parameters()

        return should_commit

    def _average_grads(self) -> None:
        """
        Efficiently averages gradients across the group using either:
        - Per-parameter allreduce (old behavior)
        - Bucketized allreduce (new behavior)
        """
        if self.use_bucketization:
            self._allreduce_bucketized()
        else:
            self._allreduce_per_param()

    def _allreduce_per_param(self) -> None:
        """Performs allreduce on each gradient tensor separately (original method)."""
        for name, p in self._model_fragment.named_parameters():
            # Perform allreduce on the pseudogradients
            work = self._manager.allreduce(
                self._grads[name], should_quantize=self.should_quantize
            )

            self._allreduce_work.append(work)

    def _bucketize_and_allreduce(
        self,
        tensors: List[torch.Tensor],
        bucket_size_bytes: int,
    ) -> None:
        """
        Applies allreduce on a list of tensors using bucketization.

        Args:
            tensors: List of torch tensors (e.g., gradients).
            bucket_size_bytes: Max size of each bucket in bytes.
        """
        if not tensors:
            return

        total_size = sum(t.numel() for t in tensors)
        dtype, device = tensors[0].dtype, tensors[0].device

        offset = 0
        flat_index = 0
        while offset < total_size:
            chunk_size = min(
                bucket_size_bytes // tensors[0].element_size(), total_size - offset
            )
            flat_buffer: torch.Tensor = torch.zeros(
                chunk_size, dtype=dtype, device=device
            )

            pack_offset: int = 0
            bucket_tensors: list[Tuple[torch.Tensor, int, int]] = []
            for t in tensors[flat_index:]:
                numel = t.numel()
                if pack_offset + numel > chunk_size:
                    break
                flat_buffer[pack_offset : pack_offset + numel].copy_(t.view(-1))
                bucket_tensors.append((t, pack_offset, numel))
                pack_offset += numel
                flat_index += 1

            work = self._manager.allreduce(
                flat_buffer, should_quantize=self.should_quantize
            )

            def callback(
                fut: torch.futures.Future[list[torch.Tensor]],
            ) -> list[torch.Tensor]:
                nonlocal bucket_tensors, flat_buffer
                for t, pack_offset, numel in bucket_tensors:
                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))

                return []

            fut = work.get_future()
            fut = fut.then(callback)

            self._allreduce_work.append(work)

            offset += chunk_size

    def _allreduce_bucketized(self) -> None:
        """
        Averages gradients using bucketized allreduce with a fixed buffer.
        """
        grads = list(self._grads.values())
        assert len(grads) > 0, "No gradients to allreduce"
        self._bucketize_and_allreduce(
            grads,
            bucket_size_bytes=self.bucket_cap_mb,
        )
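
The bucketized path above packs as many flattened pseudogradients as fit into a fixed-size flat buffer, issues one allreduce per bucket, and copies the averaged values back into the original tensors via a future callback. The standalone sketch below shows the same pack/allreduce/unpack idea; it substitutes torch.distributed.all_reduce for the manager's fault tolerant allreduce and assumes a default process group is already initialized and that all tensors share a dtype and device, so it approximates the logic rather than the torchft code path.

# Standalone sketch of the bucketization idea -- not part of torchft/local_sgd.py.
# Assumes dist.init_process_group() has been called and all tensors share
# dtype/device; each tensor must fit in one bucket.
from typing import List, Tuple

import torch
import torch.distributed as dist


def bucketized_allreduce(tensors: List[torch.Tensor], bucket_numel: int) -> None:
    assert all(t.numel() <= bucket_numel for t in tensors), "tensor larger than bucket"
    i = 0
    while i < len(tensors):
        flat = torch.zeros(
            bucket_numel, dtype=tensors[0].dtype, device=tensors[0].device
        )
        packed: List[Tuple[torch.Tensor, int, int]] = []
        offset = 0
        # Pack whole tensors until the next one would overflow the bucket.
        while i < len(tensors) and offset + tensors[i].numel() <= bucket_numel:
            n = tensors[i].numel()
            flat[offset : offset + n].copy_(tensors[i].reshape(-1))
            packed.append((tensors[i], offset, n))
            offset += n
            i += 1
        dist.all_reduce(flat)          # one collective per bucket (sum)
        flat /= dist.get_world_size()  # turn the sum into an average
        # Unpack the averaged values back into the original tensors.
        for t, off, n in packed:
            t.copy_(flat[off : off + n].view_as(t))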

class DiLoCo:
    """
    DiLoCo implements distributed optimization by averaging and synchronizing
    pseudogradients (delta of the previous global weight and current local weights).

    The class implements a more general version of DiLoCo, Streaming DiLoCo,
    which synchronizes fragments of pseudogradients at different steps.

    This algorithm requires a backup copy of the
    weights. By default these are stored in CPU memory. If any error occurs
    during the DiLoCo step, the step will be discarded and the model
    parameters will reset back to the last time DiLoCo synchronized.

    DiLoCo paper: https://arxiv.org/pdf/2311.08105
    Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512
    """

    def __init__(
        self,
        manager: Manager,
        model_fragments: List[nn.Module],
        inner_optimizer: optim.Optimizer,
        # TODO: this is for backward compatibility
        outer_optimizer: optim.Optimizer | list[optim.Optimizer],
        sync_every: int,
        backup_device: Optional[torch.device] = None,
        pin_memory: bool = True,
        use_bucketization: bool = False,
        bucket_cap_mb: Optional[int] = None,
        should_quantize: bool = False,
        fragment_sync_delay: int = 0,
        fragment_update_alpha: float = 0.0,
    ) -> None:
        """
        Args:
            manager: The manager to use.
            model_fragments: The fragments of the model to wrap.
            inner_optimizer: The optimizer used for the local parameters every step.
            outer_optimizer: The optimizer used for the global parameters updated every "sync_every" steps.
            sync_every: How often to update the model weights.
            backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
            pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
            should_quantize: Whether to quantize the gradients before allreduce.
            fragment_sync_delay: Controls the number of inner steps to wait before blocking on a fragment's
                synchronization. This is the "tau" parameter in the Streaming DiLoCo paper.
            fragment_update_alpha: Determines how to mix the local and global optimized parameters
        """

        if isinstance(outer_optimizer, list):
            assert len(outer_optimizer) == len(
                model_fragments
            ), "The number of outer optimizers must match the number of model fragments"

        if manager._use_async_quorum:
            raise ValueError(
                "Using DiLoCo require synchronous quorum to be enabled. "
                "Ensure that the manager is initialized with use_async_quorum=False"
            )

        if sync_every < len(model_fragments):
            raise ValueError("Only 1 fragment can be syncrhonized at a time")

        if sync_every % len(model_fragments) != 0:
            raise ValueError("sync_every must divide the number of fragments")

        self._sync_every: int = sync_every // len(model_fragments)
        if fragment_sync_delay >= self._sync_every:
            raise ValueError(
                "Fragment must be synced before it is reduced another time"
            )

        if fragment_update_alpha < 0 or fragment_update_alpha > 1:
            raise ValueError("fragment_update_alpha must be between 0 and 1")

        super().__init__()
        self._manager = manager

        # The number of training iterations performed.
        # Used to synchronize which fragment to send across all
        # replicas
        self._local_step = 0

        self._fragment_sync_delay = fragment_sync_delay

        self._hooks: List[RemovableHandle] = []

        self._local_optimizer = inner_optimizer

        self._fragments: List[_StreamingDiLoCoFragment] = [
            _StreamingDiLoCoFragment(
                manager,
                model_fragment,
                i,
                math.floor((sync_every / len(model_fragments)) * (i + 1)),
                inner_optimizer,
                (
                    outer_optimizer[i]
                    if isinstance(outer_optimizer, list)
                    else outer_optimizer
                ),
                sync_every,
                backup_device,
                pin_memory,
                use_bucketization,
                bucket_cap_mb,
                should_quantize,
                fragment_sync_delay,
                fragment_update_alpha,
            )
            for i, model_fragment in enumerate(model_fragments)
        ]

        # This is to make sure we adhere to the assumptions made by the
        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
        assert fragment_sync_delay < sync_every // len(model_fragments)

        # Need to copy the parameters to the host to be safe if we are on the first step.
        self._save_parameters()
        self._register_state_dict_fn()

    def _register_state_dict_fn(self) -> None:
        for fragment in self._fragments:
            fragment.register_state_dict_fn()

    def _save_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.save_parameters()

    def _restore_parameters(self) -> None:
        for fragment in self._fragments:
            fragment.restore_parameters()

    def __enter__(self) -> "DiLoCo":
        self._hooks.append(
            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
        )
        # Add optimizer hook which increments the local step counter and syncs if necessary
        self._hooks.append(
            self._local_optimizer.register_step_post_hook(self._step_post_hook)
        )
        return self

    def _step_pre_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
        self._manager.disallow_state_dict_read()

    def __exit__(
        self,
        exc_type: Optional[Type[BaseException]],
        exc_value: Optional[BaseException],
        traceback: Optional[TracebackType],
    ) -> bool:
        # Handle any cleanup or error handling here
        # Clean up hooks
        for hook in self._hooks:
            hook.remove()
        self._hooks.clear()

        return False  # Propagate exceptions

    def _wait(self) -> None:
        """
        Waits for allreduce to finish on all fragments
        """
        for fragment in self._fragments:
            fragment.wait()

    def _current_fragment(self) -> int:
        """
        Determines which fragment to prepare/sync based on the current step.
        """
        step = self._manager.current_step()
        return step % len(self._fragments)

    def _step_post_hook(
        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
    ) -> None:
        """
        This hook is registered on the optimizer and is called after the optimizer step.
        """
        self._manager.allow_state_dict_read()

        # We need to make sure all nodes send the same fragments in order.
        # This is to avoid deadlocking e.g.
        #
        # 1. Step 1 - Node A sends fragment 1
        # 2. Step 1 - Node B sends fragment 2
        # 3. Step 2 - Node A waits for fragment 1
        # 4. Step 2 - Node B waits for fragment 2
        #
        # Both of them will fail because Node A didn't send fragment 2
        # and Node B didn't send fragment 1.
        self._local_step += 1

        if self._local_step == self._sync_every - self._fragment_sync_delay:
            # Time to prepare a fragment
            #
            # Some replicas will get the same copy of the model, implying batches
            # can be overrepresented.
            self._manager.start_quorum()
            fragment = self._current_fragment()
            logger.info(f"Preparing fragment={fragment} step={self._local_step}")
            self._fragments[fragment].prepare_sync()

        if self._local_step < self._sync_every:
            return

        if self._local_step == self._sync_every:
            # Time to sync a fragment
            fragment = self._current_fragment()
            logger.info(
                f"Syncing fragment={fragment} step={self._local_step} manager_step={self._manager.current_step()}"
            )
            self._fragments[fragment].perform_sync()

            # If the allreduce truly failed, we'll keep retrying this fragment.
            # We reset the parameters upon failure. We'll skip over some data
            # but we won't over train before syncing.

            self._local_step = 0
            return

        assert (
            False
        ), f"{self._local_step=} should never be greater than {self._sync_every=}"