torchft-nightly 2026.1.3__cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/local_sgd.py ADDED
@@ -0,0 +1,796 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
+"""
+LocalSGD
+=========
+This module implements a fault tolerant version of LocalSGD and related methods.
+"""
+
+import logging
+import math
+import os
+from contextlib import nullcontext
+from types import TracebackType
+from typing import Any, Dict, List, Optional, Tuple, Type
+
+import torch
+from torch import nn, optim
+from torch.distributed.distributed_c10d import Work
+from torch.distributed.tensor import DTensor
+from torch.utils.hooks import RemovableHandle
+
+from torchft.manager import Manager
+
+logger: logging.Logger = logging.getLogger(__name__)
+
+USE_BUCKETIZATION_ENV: str = "TORCHFT_USE_BUCKETIZATION"
+
+
+def extract_local_tensor(t: torch.Tensor) -> torch.Tensor:
+    """
+    Returns a cloned version of the input tensor. If the input tensor is a DTensor,
+    it extracts and clones its local representation.
+    """
+    new_tensor = None
+    if isinstance(t, DTensor):
+        new_tensor = t.to_local().clone()
+    else:
+        new_tensor = t.clone()
+    new_tensor.grad = None
+    return new_tensor
+
+
+class LocalSGD:
+    """
+    LocalSGD is a context manager that
+    implements the algorithm described in https://arxiv.org/pdf/1805.09767
+
+    This will synchronize the model parameters periodically in a fault tolerant
+    way using a torchft Manager. The allreduce on the parameters will happen
+    every sync_every steps after the optimizer.step call.
+
+    The torchft quorum is computed at the beginning of ``sync_every`` steps. If
+    any error occurs, or a worker fails between syncs, ``sync_every`` steps will be
+    discarded and a new quorum will be computed on the next step.
+
+    If running in async mode, on a joining worker the first ``sync_every`` steps
+    will be discarded as the model will be recovering during that period. When
+    using sync mode, the checkpoint will be restored prior to the first step.
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model: nn.Module,
+        optimizer: optim.Optimizer,
+        sync_every: int,
+    ) -> None:
+        """
+        Args:
+            manager: The manager to use.
+            model: The model to wrap.
+            optimizer: The optimizer used by the model.
+            sync_every: How often to sync the model weights.
+        """
+        super().__init__()
+        self._manager = manager
+        self._model = model
+        self._local_optimizer = optimizer
+        self._local_step = 0
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+
+        self._hooks: List[RemovableHandle] = []
+
+    def __enter__(self) -> "LocalSGD":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._manager.allow_state_dict_read()
+
+        self._local_step += 1
+        if self._local_step >= self._sync_every:
+            self.sync()
+
+    def sync(self) -> None:
+        """
+        Synchronizes and averages the model weights across the manager.
+        """
+        self._manager.start_quorum()
+        self._perform_sync()
+        self._local_step = 0
+
+    def _perform_sync(self) -> None:
+        """
+        Performs the synchronization of the model weights across the manager.
+        """
+        averaged_parameters = self._average()
+        if self._manager.should_commit():
+            # Update the model parameters with the averaged values
+            for param, avg_param in zip(self._model.parameters(), averaged_parameters):
+                if isinstance(param, DTensor):
+                    # we averaged the local version of the tensor so need to copy it back as a DTensor
+                    param.data.copy_(
+                        DTensor.from_local(
+                            avg_param,
+                            param.device_mesh,
+                            param.placements,
+                            shape=param.shape,
+                            stride=param.stride(),
+                        )
+                    )
+                else:
+                    param.data.copy_(avg_param)
+
+    def _average(self) -> list[torch.Tensor]:
+        """
+        Averages the model parameters across the manager and returns the averaged parameters.
+        """
+        works = []
+        averaged_parameters = []
+        for p in self._model.parameters():
+            # Create a new tensor to store the averaged parameter
+            avg_param = extract_local_tensor(p)
+            works.append(self._manager.allreduce(avg_param))
+            averaged_parameters.append(avg_param)
+        for work in works:
+            work.wait()
+        return averaged_parameters
+
+
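For reference, a minimal usage sketch of the LocalSGD context manager defined above. The Manager, model, optimizer, criterion, and data loader are assumed to be constructed elsewhere; the names and sync_every value here are illustrative only.

import torch
from torchft.local_sgd import LocalSGD
from torchft.manager import Manager

def train_with_local_sgd(
    manager: Manager,
    model: torch.nn.Module,
    optimizer: torch.optim.Optimizer,
    criterion: torch.nn.Module,
    trainloader,
) -> None:
    # Parameters are allreduced and averaged every 100 optimizer steps;
    # between syncs each replica trains purely locally.
    with LocalSGD(manager, model, optimizer, sync_every=100):
        for inputs, labels in trainloader:
            optimizer.zero_grad()
            loss = criterion(model(inputs), labels)
            loss.backward()
            # The post-step hook registered by LocalSGD counts this step and
            # calls sync() once sync_every steps have accumulated.
            optimizer.step()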
+class _StreamingDiLoCoFragment:
+    bucket_cap_mb: int = 1 * 1024 * 1024 * 1024
+    use_bucketization: bool = False
+
+    def __init__(
+        self,
+        manager: Manager,
+        model_fragment: nn.Module,
+        fragment_id: int,
+        fragment_sync_offset: int,
+        inner_optimizer: optim.Optimizer,
+        outer_optimizer: optim.Optimizer,
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+        use_bucketization: bool = False,
+        bucket_cap_mb: Optional[int] = None,
+        should_quantize: bool = False,
+        fragment_sync_delay: int = 0,
+        fragment_update_alpha: float = 0.0,
+    ) -> None:
+        if fragment_sync_offset > sync_every:
+            raise ValueError("Fragment must be synced once before `sync_every` steps")
+
+        self._fragment_id = fragment_id
+        self._manager = manager
+        self._model_fragment = model_fragment
+        self._fragment_sync_offset = fragment_sync_offset
+        self._local_optimizer = inner_optimizer
+        self._sync_every = sync_every
+        assert sync_every >= 1, "sync_every must be greater than or equal to 1"
+        self._backup_device = backup_device
+        self._pin_memory = pin_memory
+        self._fragment_sync_delay = fragment_sync_delay
+        self._fragment_update_alpha = fragment_update_alpha
+
+        self._outer_optimizer = outer_optimizer
+
+        # Stores pending allreduces
+        self._allreduce_work: list[Work] = []
+        self._stream: Optional[torch.cuda.Stream] = (
+            torch.cuda.Stream() if torch.cuda.is_available() else None
+        )
+
+        # Recorded on `_stream` to wait for allreduce to finish
+        self._stop_event: Optional[torch.cuda.Event] = None
+
+        if bucket_cap_mb is not None:
+            self.bucket_cap_mb = int(bucket_cap_mb * 1024 * 1024)
+
+        if os.getenv(USE_BUCKETIZATION_ENV, "False") == "True":
+            self.use_bucketization = True
+        else:
+            self.use_bucketization = use_bucketization
+
+        self.should_quantize = should_quantize
+
+        self._grads: Dict[str, torch.Tensor] = {}
+
+        # Used to save global parameters so that they can be restored in case
+        # commit fails
+        self.original_parameters: Dict[str, torch.Tensor] = {}
+
+        # Used to mix the local and global parameters
+        self._local_parameters: Dict[str, torch.Tensor] = {}
+
+        for name, p in self._model_fragment.named_parameters():
+            if isinstance(p, DTensor):
+                p = extract_local_tensor(p.data)
+
+            backup_device = self._backup_device or torch.device("cpu")
+            t = torch.empty(*tuple(p.shape), dtype=p.dtype, device=backup_device)
+            if (
+                self._pin_memory
+                and t.device == torch.device("cpu")
+                and torch.cuda.is_available()
+            ):
+                t = t.pin_memory()
+            self.original_parameters[name] = t
+
+    def register_state_dict_fn(self) -> None:
+        """
+        Registers state dict functions for this fragment with the manager.
+        This allows for saving and loading the original_parameters during checkpointing and recovery.
+        """
+        # Generate a unique key for this fragment based on its fragment ID
+        fragment_key = f"StreamingDiLoCoFragment_{self._fragment_id}"
+
+        # Define load function for this fragment
+        def load_fn(state_dict: Dict[str, Dict[str, torch.Tensor]]) -> None:
+            for name, param in state_dict["original_parameters"].items():
+                if name in self.original_parameters:
+                    self.original_parameters[name].copy_(param)
+
+            self._outer_optimizer.load_state_dict(state_dict["outer_optimizer"])
+
+        # Define save function for this fragment
+        def save_fn() -> Dict[str, Dict[str, torch.Tensor]]:
+            return {
+                "outer_optimizer": self._outer_optimizer.state_dict(),
+                "original_parameters": {
+                    name: extract_local_tensor(param)
+                    for name, param in self.original_parameters.items()
+                },
+            }
+
+        # Register the functions with the manager
+        self._manager.register_state_dict_fn(fragment_key, load_fn, save_fn)
+
+    @torch.profiler.record_function("torchft::local_sgd::save_parameters")
+    def save_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model_fragment.named_parameters():
+                param_to_local = extract_local_tensor(p.data)
+                self.original_parameters[name].copy_(param_to_local, non_blocking=True)
+
+    def _save_local_parameters(self) -> None:
+        """
+        Saves a copy of the model's parameters.
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                self._local_parameters[name] = extract_local_tensor(p.data)
+
+    @torch.profiler.record_function("torchft::local_sgd::restore_parameters")
+    def restore_parameters(self) -> None:
+        with torch.no_grad():
+            # TODO: consider running copy on a separate stream
+            for name, p in self._model_fragment.named_parameters():
+                if isinstance(p, DTensor):
+                    # we saved the local version of the tensor so need to copy it back as a DTensor
+                    p.data.copy_(
+                        DTensor.from_local(
+                            self.original_parameters[name],
+                            p.device_mesh,
+                            p.placements,
+                            shape=p.shape,
+                            stride=p.stride(),
+                        ),
+                        non_blocking=False,
+                    )
+                else:
+                    p.data.copy_(self.original_parameters[name], non_blocking=False)
+
+    def _save_grads(self) -> None:
+        """
+        Saves pseudo-gradients of the parameters
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                if isinstance(p, DTensor):
+                    local_param = p.to_local()
+                else:
+                    local_param = p
+                pseudogradient = (
+                    self.original_parameters[name].to(p.device) - local_param
+                )
+                self._grads[name] = pseudogradient
+
+    def _set_grads(self) -> None:
+        """
+        Sets the gradients of the model fragment from the allreduce result
+        """
+        with torch.no_grad():
+            for name, p in self._model_fragment.named_parameters():
+                # avoid copying the gradient, it should be on the same device
+                if isinstance(p, DTensor):
+                    p.grad = DTensor.from_local(
+                        self._grads[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    )
+                else:
+                    p.grad = self._grads[name]
+
+                # No longer needed
+                del self._grads[name]
+
+    def _clear_local_parameters(self) -> None:
+        """
+        Clears the saved copy of the model's parameters
+        """
+        self._local_parameters = {}
+
+    def _merge_parameters(self) -> None:
+        """
+        Merges the local and global parameters.
+        """
+        for name, p in self._model_fragment.named_parameters():
+            # we averaged the local version of the tensor so need to copy it back as a DTensor
+            if isinstance(p, DTensor):
+                p.data.lerp_(
+                    DTensor.from_local(
+                        self._local_parameters[name],
+                        p.device_mesh,
+                        p.placements,
+                        shape=p.shape,
+                        stride=p.stride(),
+                    ),
+                    self._fragment_update_alpha,
+                )
+            else:
+                p.data.lerp_(self._local_parameters[name], self._fragment_update_alpha)
+
+    @torch.profiler.record_function("torchft::local_sgd::wait")
+    def wait(self) -> None:
+        """
+        Waits for the previously scheduled allreduce to finish
+        """
+        if len(self._allreduce_work) == 0:
+            return
+
+        if self._stream is not None:
+            assert self._stop_event is not None
+            self._stop_event.synchronize()
+            self._stop_event = None
+
+        self._allreduce_work = []
+
+    @torch.profiler.record_function("torchft::local_sgd::prepare_sync")
+    def prepare_sync(self) -> None:
+        """
+        Calculates the pseudogradients and starts an allreduce to average them
+        across the manager group, but doesn't wait for it to finish.
+        """
+        self._save_grads()
+
+        assert len(self._allreduce_work) == 0
+
+        # Make sure tensors are available to `_stream`
+        if self._stream is not None:
+            self._stream.wait_stream(torch.cuda.current_stream())
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            self._average_grads()
+
+    @torch.profiler.record_function("torchft::local_sgd::perform_sync")
+    def perform_sync(self) -> bool:
+        """
+        Overrides the sync method to wait for the scheduled allreduce to finish and
+        steps using the outer optimizer.
+        """
+        # Waiting for an allreduce before it has been sent is currently not supported.
+        assert len(self._allreduce_work) > 0
+
+        with (
+            torch.cuda.stream(self._stream)
+            if self._stream is not None
+            else nullcontext()
+        ):
+            for work in self._allreduce_work:
+                work.wait()
+
+            if self._stream is not None:
+                self._stop_event = torch.cuda.Event()
+                self._stop_event.record()
+
+        self.wait()
+
+        # save the parameters so they can be used for merging
+        self._save_local_parameters()
+        # Restore the parameters back to the previous state
+        self.restore_parameters()
+
+        # For large values of `fragment_sync_delay`, this call can be
+        # a problem.
+        #
+        # This can return success even if the allreduce failed, because
+        # the process group could have been reconfigured while the
+        # allreduce was in flight. The in-flight allreduce may or may
+        # not have been aborted.
+        #
+        # We could track errors per allreduce to
+        # let the commit fail here, but this has the downside of
+        # reconfiguring the pg too many times, resulting in
+        # more aborts and more commit failures.
+        should_commit = self._manager.should_commit()
+
+        if should_commit:
+            # Use the outer optimizer to update the model parameters
+            self._set_grads()
+            self._outer_optimizer.step()
+            self.save_parameters()
+            self._merge_parameters()
+        self._outer_optimizer.zero_grad()
+
+        # free up memory
+        self._clear_local_parameters()
+
+        return should_commit
+
+    def _average_grads(self) -> None:
+        """
+        Efficiently averages gradients across the group using either:
+        - Per-parameter allreduce (old behavior)
+        - Bucketized allreduce (new behavior)
+        """
+        if self.use_bucketization:
+            self._allreduce_bucketized()
+        else:
+            self._allreduce_per_param()
+
+    def _allreduce_per_param(self) -> None:
+        """Performs allreduce on each gradient tensor separately (original method)."""
+        for name, p in self._model_fragment.named_parameters():
+            # Perform allreduce on the pseudogradients
+            work = self._manager.allreduce(
+                self._grads[name], should_quantize=self.should_quantize
+            )
+
+            self._allreduce_work.append(work)
+
+    def _bucketize_and_allreduce(
+        self,
+        tensors: List[torch.Tensor],
+        bucket_size_bytes: int,
+    ) -> None:
+        """
+        Applies allreduce on a list of tensors using bucketization.
+
+        Args:
+            tensors: List of torch tensors (e.g., gradients).
+            bucket_size_bytes: Max size of each bucket in bytes.
+        """
+        if not tensors:
+            return
+
+        total_size = sum(t.numel() for t in tensors)
+        dtype, device = tensors[0].dtype, tensors[0].device
+
+        offset = 0
+        flat_index = 0
+        while offset < total_size:
+            chunk_size = min(
+                bucket_size_bytes // tensors[0].element_size(), total_size - offset
+            )
+            flat_buffer: torch.Tensor = torch.zeros(
+                chunk_size, dtype=dtype, device=device
+            )
+
+            pack_offset: int = 0
+            bucket_tensors: list[Tuple[torch.Tensor, int, int]] = []
+            for t in tensors[flat_index:]:
+                numel = t.numel()
+                if pack_offset + numel > chunk_size:
+                    break
+                flat_buffer[pack_offset : pack_offset + numel].copy_(t.view(-1))
+                bucket_tensors.append((t, pack_offset, numel))
+                pack_offset += numel
+                flat_index += 1
+
+            work = self._manager.allreduce(
+                flat_buffer, should_quantize=self.should_quantize
+            )
+
+            def callback(
+                fut: torch.futures.Future[list[torch.Tensor]],
+            ) -> list[torch.Tensor]:
+                nonlocal bucket_tensors, flat_buffer
+                for t, pack_offset, numel in bucket_tensors:
+                    t.copy_(flat_buffer[pack_offset : pack_offset + numel].view_as(t))
+
+                return []
+
+            fut = work.get_future()
+            fut = fut.then(callback)
+
+            self._allreduce_work.append(work)
+
+            offset += chunk_size
+
+    def _allreduce_bucketized(self) -> None:
+        """
+        Averages gradients using bucketized allreduce with a fixed buffer.
+        """
+        grads = list(self._grads.values())
+        assert len(grads) > 0, "No gradients to allreduce"
+        self._bucketize_and_allreduce(
+            grads,
+            bucket_size_bytes=self.bucket_cap_mb,
+        )
+
+
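The bucketized path in _bucketize_and_allreduce above packs as many pseudogradients as fit into a fixed-size flat buffer, allreduces the buffer, then scatters the averaged values back into the original tensors. A standalone sketch of that pack/unpack logic, simplified under the assumption of a single dtype and no Manager or quantization:

import torch

def pack_bucket(tensors: list[torch.Tensor], bucket_numel: int):
    """Pack a prefix of `tensors` into one flat buffer, as the fragment does per bucket."""
    flat = torch.zeros(bucket_numel, dtype=tensors[0].dtype, device=tensors[0].device)
    slots = []  # (tensor, offset, numel) triples, mirroring bucket_tensors above
    offset = 0
    for t in tensors:
        n = t.numel()
        if offset + n > bucket_numel:
            break  # the remaining tensors go into the next bucket
        flat[offset : offset + n].copy_(t.view(-1))
        slots.append((t, offset, n))
        offset += n
    return flat, slots

def unpack_bucket(flat: torch.Tensor, slots) -> None:
    """Copy the (averaged) flat buffer back into the original tensors."""
    for t, offset, n in slots:
        t.copy_(flat[offset : offset + n].view_as(t))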
+class DiLoCo:
+    """
+    DiLoCo implements distributed optimization by averaging and synchronizing
+    pseudogradients (delta of the previous global weight and current local weights).
+
+    The class implements a more general version of DiLoCo, Streaming DiLoCo,
+    which synchronizes fragments of pseudogradients at different steps.
+
+    This algorithm requires a backup copy of the
+    weights. By default these are stored in CPU memory. If any error occurs
+    during the DiLoCo step, the step will be discarded and the model
+    parameters will reset back to the last time DiLoCo synchronized.
+
+    DiLoCo paper: https://arxiv.org/pdf/2311.08105
+    Streaming DiLoCo paper: https://arxiv.org/pdf/2501.18512
+    """
+
+    def __init__(
+        self,
+        manager: Manager,
+        model_fragments: List[nn.Module],
+        inner_optimizer: optim.Optimizer,
+        # TODO: this is for backward compatibility
+        outer_optimizer: optim.Optimizer | list[optim.Optimizer],
+        sync_every: int,
+        backup_device: Optional[torch.device] = None,
+        pin_memory: bool = True,
+        use_bucketization: bool = False,
+        bucket_cap_mb: Optional[int] = None,
+        should_quantize: bool = False,
+        fragment_sync_delay: int = 0,
+        fragment_update_alpha: float = 0.0,
+    ) -> None:
+        """
+        Args:
+            manager: The manager to use.
+            model_fragments: The fragments of the model to wrap.
+            inner_optimizer: The optimizer used for the local parameters every step.
+            outer_optimizer: The optimizer used for the global parameters, updated every "sync_every" steps.
+            sync_every: How often to update the model weights.
+            backup_device: The device to store the backup weights on. If None, the backup weights will be on CPU.
+            pin_memory: Whether to pin the memory for the backup weights (only for CPU device).
+            should_quantize: Whether to quantize the gradients before allreduce.
+            fragment_sync_delay: Controls the number of inner steps to wait before blocking on a fragment's
+                synchronization. This is the "tau" parameter in the Streaming DiLoCo paper.
+            fragment_update_alpha: Determines how to mix the local and global optimized parameters.
+        """
+
+        if isinstance(outer_optimizer, list):
+            assert len(outer_optimizer) == len(
+                model_fragments
+            ), "The number of outer optimizers must match the number of model fragments"
+
+        if manager._use_async_quorum:
+            raise ValueError(
+                "Using DiLoCo requires synchronous quorum to be enabled. "
+                "Ensure that the manager is initialized with use_async_quorum=False"
+            )
+
+        if sync_every < len(model_fragments):
+            raise ValueError("Only 1 fragment can be synchronized at a time")
+
+        if sync_every % len(model_fragments) != 0:
+            raise ValueError("The number of fragments must evenly divide sync_every")
+
+        self._sync_every: int = sync_every // len(model_fragments)
+        if fragment_sync_delay >= self._sync_every:
+            raise ValueError(
+                "Fragment must be synced before it is reduced another time"
+            )
+
+        if fragment_update_alpha < 0 or fragment_update_alpha > 1:
+            raise ValueError("fragment_update_alpha must be between 0 and 1")
+
+        super().__init__()
+        self._manager = manager
+
+        # The number of training iterations performed.
+        # Used to synchronize which fragment to send across all
+        # replicas
+        self._local_step = 0
+
+        self._fragment_sync_delay = fragment_sync_delay
+
+        self._hooks: List[RemovableHandle] = []
+
+        self._local_optimizer = inner_optimizer
+
+        self._fragments: List[_StreamingDiLoCoFragment] = [
+            _StreamingDiLoCoFragment(
+                manager,
+                model_fragment,
+                i,
+                math.floor((sync_every / len(model_fragments)) * (i + 1)),
+                inner_optimizer,
+                (
+                    outer_optimizer[i]
+                    if isinstance(outer_optimizer, list)
+                    else outer_optimizer
+                ),
+                sync_every,
+                backup_device,
+                pin_memory,
+                use_bucketization,
+                bucket_cap_mb,
+                should_quantize,
+                fragment_sync_delay,
+                fragment_update_alpha,
+            )
+            for i, model_fragment in enumerate(model_fragments)
+        ]
+
+        # This is to make sure we adhere to the assumptions made by
+        # `_StreamingDiLoCoFragment` about the fragment sync schedule.
+        assert fragment_sync_delay < sync_every // len(model_fragments)
+
+        # Need to copy the parameters to the host to be safe if we are on the first step.
+        self._save_parameters()
+        self._register_state_dict_fn()
+
+    def _register_state_dict_fn(self) -> None:
+        for fragment in self._fragments:
+            fragment.register_state_dict_fn()
+
+    def _save_parameters(self) -> None:
+        for fragment in self._fragments:
+            fragment.save_parameters()
+
+    def _restore_parameters(self) -> None:
+        for fragment in self._fragments:
+            fragment.restore_parameters()
+
+    def __enter__(self) -> "DiLoCo":
+        self._hooks.append(
+            self._local_optimizer.register_step_pre_hook(self._step_pre_hook)
+        )
+        # Add optimizer hook which increments the local step counter and syncs if necessary
+        self._hooks.append(
+            self._local_optimizer.register_step_post_hook(self._step_post_hook)
+        )
+        return self
+
+    def _step_pre_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        # The checkpoint may transfer model parameters, so we need to make access to it thread safe
+        self._manager.disallow_state_dict_read()
+
+    def __exit__(
+        self,
+        exc_type: Optional[Type[BaseException]],
+        exc_value: Optional[BaseException],
+        traceback: Optional[TracebackType],
+    ) -> bool:
+        # Handle any cleanup or error handling here
+        # Clean up hooks
+        for hook in self._hooks:
+            hook.remove()
+        self._hooks.clear()
+
+        return False  # Propagate exceptions
+
+    def _wait(self) -> None:
+        """
+        Waits for allreduce to finish on all fragments
+        """
+        for fragment in self._fragments:
+            fragment.wait()
+
+    def _current_fragment(self) -> int:
+        """
+        Determines which fragment to prepare/sync based on the current step.
+        """
+        step = self._manager.current_step()
+        return step % len(self._fragments)
+
+    def _step_post_hook(
+        self, _optim: optim.Optimizer, _args: Tuple[Any, ...], _kwargs: Dict[str, Any]
+    ) -> None:
+        """
+        This hook is registered on the optimizer and is called after the optimizer step.
+        """
+        self._manager.allow_state_dict_read()
+
+        # We need to make sure all nodes send the same fragments in order.
+        # This is to avoid deadlocking, e.g.
+        #
+        # 1. Step 1 - Node A sends fragment 1
+        # 2. Step 1 - Node B sends fragment 2
+        # 3. Step 2 - Node A waits for fragment 1
+        # 4. Step 2 - Node B waits for fragment 2
+        #
+        # Both of them will fail because Node A didn't send fragment 2
+        # and Node B didn't send fragment 1.
+        self._local_step += 1
+
+        if self._local_step == self._sync_every - self._fragment_sync_delay:
+            # Time to prepare a fragment
+            #
+            # Some replicas will get the same copy of the model, implying batches
+            # can be overrepresented.
+            self._manager.start_quorum()
+            fragment = self._current_fragment()
+            logger.info(f"Preparing fragment={fragment} step={self._local_step}")
+            self._fragments[fragment].prepare_sync()
+
+        if self._local_step < self._sync_every:
+            return
+
+        if self._local_step == self._sync_every:
+            # Time to sync a fragment
+            fragment = self._current_fragment()
+            logger.info(
+                f"Syncing fragment={fragment} step={self._local_step} manager_step={self._manager.current_step()}"
+            )
+            self._fragments[fragment].perform_sync()
+
+            # If the allreduce truly failed, we'll keep retrying this fragment.
+            # We reset the parameters upon failure. We'll skip over some data
+            # but we won't over train before syncing.
+
+            self._local_step = 0
+            return
+
+        assert (
+            False
+        ), f"{self._local_step=} should never be greater than {self._sync_every=}"