torchft_nightly-2026.1.3-cp310-cp310-manylinux_2_24_x86_64.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (61)
  1. torchft/__init__.py +34 -0
  2. torchft/_test/diloco_trainer.py +287 -0
  3. torchft/_test/managed_work_test.py +320 -0
  4. torchft/_test_utils.py +111 -0
  5. torchft/_torchft.cpython-310-x86_64-linux-gnu.so +0 -0
  6. torchft/_torchft.pyi +116 -0
  7. torchft/checkpointing/__init__.py +20 -0
  8. torchft/checkpointing/_rwlock.py +136 -0
  9. torchft/checkpointing/_serialization.py +39 -0
  10. torchft/checkpointing/http_transport.py +299 -0
  11. torchft/checkpointing/http_transport_bench.py +61 -0
  12. torchft/checkpointing/http_transport_test.py +146 -0
  13. torchft/checkpointing/pg_transport.py +306 -0
  14. torchft/checkpointing/pg_transport_bench.py +99 -0
  15. torchft/checkpointing/pg_transport_test.py +101 -0
  16. torchft/checkpointing/rwlock_test.py +58 -0
  17. torchft/checkpointing/transport.py +68 -0
  18. torchft/checkpointing/transport_test.py +161 -0
  19. torchft/collectives.py +415 -0
  20. torchft/collectives_test.py +212 -0
  21. torchft/coordination.py +39 -0
  22. torchft/coordination_test.py +29 -0
  23. torchft/data.py +77 -0
  24. torchft/data_test.py +39 -0
  25. torchft/ddp.py +105 -0
  26. torchft/ddp_test.py +68 -0
  27. torchft/diloco_regression_test.py +644 -0
  28. torchft/examples/slurm/README.md +34 -0
  29. torchft/examples/slurm/punisher.py +95 -0
  30. torchft/examples/slurm/runner.py +221 -0
  31. torchft/fsdp_test.py +102 -0
  32. torchft/futures.py +353 -0
  33. torchft/futures_test.py +140 -0
  34. torchft/http.py +13 -0
  35. torchft/lighthouse_test.py +163 -0
  36. torchft/local_sgd.py +796 -0
  37. torchft/local_sgd_integ_test.py +600 -0
  38. torchft/local_sgd_test.py +324 -0
  39. torchft/manager.py +1358 -0
  40. torchft/manager_integ_test.py +653 -0
  41. torchft/manager_test.py +911 -0
  42. torchft/multiprocessing.py +38 -0
  43. torchft/multiprocessing_dummy_context.py +135 -0
  44. torchft/multiprocessing_test.py +58 -0
  45. torchft/optim.py +63 -0
  46. torchft/optim_test.py +50 -0
  47. torchft/otel.py +134 -0
  48. torchft/parameter_server.py +195 -0
  49. torchft/parameter_server_test.py +47 -0
  50. torchft/process_group.py +2118 -0
  51. torchft/process_group_test.py +1028 -0
  52. torchft/quantization.py +686 -0
  53. torchft/quantization_test.py +131 -0
  54. torchft/torchx.py +89 -0
  55. torchft/utils.py +67 -0
  56. torchft/work.py +26 -0
  57. torchft_nightly-2026.1.3.dist-info/METADATA +308 -0
  58. torchft_nightly-2026.1.3.dist-info/RECORD +61 -0
  59. torchft_nightly-2026.1.3.dist-info/WHEEL +4 -0
  60. torchft_nightly-2026.1.3.dist-info/entry_points.txt +2 -0
  61. torchft_nightly-2026.1.3.dist-info/licenses/LICENSE +34 -0
torchft/diloco_regression_test.py
@@ -0,0 +1,644 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the BSD-style license found in the
5
+ # LICENSE file in the root directory of this source tree.
6
+
7
+ import argparse
8
+ import copy
9
+ import json
10
+ import logging
11
+ import os
12
+ import sys
13
+ import threading
14
+ from concurrent.futures import as_completed, ThreadPoolExecutor
15
+ from contextlib import ExitStack
16
+ from datetime import timedelta
17
+ from pathlib import Path
18
+ from typing import Any, Callable, cast, Dict, Iterator, List, Optional, overload, Tuple
19
+ from unittest import skipIf, TestCase
20
+
21
+ import torch
22
+ from parameterized import parameterized
23
+ from torch import nn, optim
24
+
25
+ from torchft._test.diloco_trainer import DiLoCoTrainer, MultiModel
26
+ from torchft._torchft import LighthouseServer
27
+ from torchft.local_sgd import DiLoCo
28
+ from torchft.manager import Manager
29
+ from torchft.manager_integ_test import EventInjector, EventInjectorEvent, Runner
30
+
31
+ logging.basicConfig(level=logging.INFO)
32
+ logger: logging.Logger = logging.getLogger(__name__)
33
+
34
+
35
+ def handle_fixture(
36
+ fixture_filename: str,
37
+ results: list[list[Dict[str, Dict[int, Dict[str, List[float]]]]]],
38
+ ) -> Optional[list[list[Dict[str, Dict[str, Dict[str, List[float]]]]]]]:
39
+ """
40
+ Handle reading from or writing to a fixture file.
41
+
42
+ Args:
43
+ fixture_filename: The name of the fixture file (without path)
44
+ results: The results to write to the fixture file if in write mode
45
+
46
+ Returns:
47
+ The fixture data when reading, None when writing
48
+ """
49
+ script_directory = os.path.dirname(os.path.abspath(__file__))
50
+ root_directory = os.path.dirname(script_directory)
51
+
52
+ fixture_path = os.path.join(root_directory, "test_fixtures", fixture_filename)
53
+
54
+ write_fixture = os.environ.get("WRITE_FIXTURE", "false").lower() == "true"
55
+
56
+ if write_fixture:
57
+ # Write results to fixture file
58
+ logger.info(f"Writing fixture to {fixture_path}")
59
+ with open(fixture_path, "w+") as f:
60
+ json.dump(results, f, indent=2)
61
+ return None
62
+
63
+ # Read fixture file and return the data
64
+ assert os.path.exists(fixture_path), f"Fixture file {fixture_path} does not exist"
65
+ logger.info(f"Validating against fixture at {fixture_path}")
66
+ with open(fixture_path, "r") as f:
67
+ fixture_data = json.load(f)
68
+
69
+ return fixture_data
70
+
71
+
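`handle_fixture` switches between writing and reading based on the `WRITE_FIXTURE` environment variable. Below is a minimal standalone sketch of the same round trip with plain `json` (hypothetical path and data, not part of the package); it also shows why the comparison code later calls `int(step)`: JSON serialization turns the integer step keys into strings.

```python
import json
import os
import tempfile

# Hypothetical results shaped like the test output: one entry per replica.
results = [[{"history": {0: {"layers.0.weight": [[1.0]]}}}]]

fixture_path = os.path.join(tempfile.gettempdir(), "example_fixture.json")

# "Write mode" (WRITE_FIXTURE=true): dump the results to the fixture file.
with open(fixture_path, "w+") as f:
    json.dump(results, f, indent=2)

# "Read mode" (default): load the fixture back for comparison. The integer
# step keys come back as strings, e.g. "0" instead of 0.
with open(fixture_path, "r") as f:
    fixture_data = json.load(f)

print(fixture_data[0][0]["history"]["0"]["layers.0.weight"])  # [[1.0]]
```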
72
+ class MockLinear(nn.Module):
73
+ """
74
+ A mock linear layer with deterministic parameter updates.
75
+ """
76
+
77
+ def __init__(self, in_features: int, out_features: int) -> None:
78
+ super().__init__()
79
+ # Initialize with specific values to make tracking easier
80
+ self.weight = nn.Parameter(torch.ones(out_features, in_features))
81
+
82
+ # Fixed gradients for deterministic updates
83
+ self.weight_grad_value = 2
84
+
85
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
86
+ # We don't actually do a forward pass, this should not be called
87
+ raise NotImplementedError("MockLinear.forward should not be called")
88
+
89
+
90
+ class MockModel(MultiModel):
91
+ """
92
+ A mock model with deterministic parameter updates.
93
+ """
94
+
95
+ def __init__(self, in_dim: int = 3, out_dim: int = 4, n_layers: int = 1) -> None:
96
+ super().__init__()
97
+
98
+ for _ in range(n_layers):
99
+ # We don't care about matching dimensionality, we're not going to pass any
100
+ # input through the model
101
+ self.layers.append(MockLinear(in_dim, out_dim))
102
+
103
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
104
+ # We don't actually do a forward pass, this should not be called
105
+ raise NotImplementedError("MockModel.forward should not be called")
106
+
107
+
108
+ class MockOptimizer(optim.Optimizer):
109
+ """
110
+ A mock optimizer with deterministic parameter updates.
111
+ """
112
+
113
115
+ def __init__(self, params: Iterator[torch.nn.Parameter], lr: float = 0.1) -> None:
116
+ defaults = dict(lr=lr)
117
+ super(MockOptimizer, self).__init__(params, defaults)
118
+
119
+ @overload
120
+ def step(self, closure: None = None) -> None: ...
121
+
122
+ @overload
123
+ def step(self, closure: Callable[[], float]) -> float: ...
124
+
125
+ def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
126
+ for group in self.param_groups:
127
+ for p in group["params"]:
128
+ if p.grad is None:
129
+ continue
130
+
131
+ # Apply a fixed update rule: subtract lr * grad
132
+ p.data.add_(p.grad.data, alpha=-group["lr"])
133
+
134
+
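For reference, a minimal standalone sketch of the update rule `MockOptimizer.step` applies, reproduced here with plain `torch.optim.SGD` (which performs the same `p -= lr * grad` step): with the weight initialized to 1, a fixed gradient of 2, and an inner learning rate of 1, each inner step moves the weight down by 2.

```python
import torch
from torch import nn, optim

# Weight initialized to 1 (as in MockLinear) with a fixed gradient of 2.
w = nn.Parameter(torch.ones(1, 1))
w.grad = torch.full_like(w, 2.0)

# Plain SGD applies the same rule as MockOptimizer.step: p <- p - lr * grad.
opt = optim.SGD([w], lr=1.0)
opt.step()
print(w.data)  # tensor([[-1.]]): each inner step subtracts lr * grad = 2
```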
135
+ class MockDiLoCoTrainer(DiLoCoTrainer):
136
+ """
137
+ A customized DiLoCoTrainer that uses mock components for deterministic parameter updates.
138
+ """
139
+
140
+ def __init__(
141
+ self,
142
+ rank: int,
143
+ store_port: int,
144
+ device: torch.device,
145
+ runner: Runner,
146
+ model_state_dict: dict[str, Any],
147
+ n_fragments: int,
148
+ diloco_args: dict[str, Any],
149
+ inner_lr: float = 1,
150
+ outer_lr: float = 2,
151
+ quorum_barrier: Optional[threading.Barrier] = None,
152
+ ) -> None:
153
+ self.inner_lr = inner_lr
154
+ self.outer_lr = outer_lr
155
+
156
+ # Call parent constructor
157
+ super().__init__(
158
+ rank, store_port, device, runner, model_state_dict, n_fragments, diloco_args
159
+ )
160
+
161
+ self.quorum_barrier = quorum_barrier
162
+
163
+ def setup_model(self) -> MockModel:
164
+ """Set up the mock model and move it to the device."""
165
+ model = MockModel(in_dim=1, out_dim=1, n_layers=self.n_fragments)
166
+ model.load_state_dict(self.model_state_dict)
167
+ model.to(self.device)
168
+ return model
169
+
170
+ def setup_inner_optimizer(self) -> torch.optim.Optimizer:
171
+ """Set up the mock inner optimizer."""
172
+ return MockOptimizer(self.model.parameters(), lr=self.inner_lr)
173
+
174
+ def setup_outer_optimizers(self) -> list[torch.optim.Optimizer]:
175
+ """Set up mock outer optimizers."""
176
+ outer_optimizers = []
177
+ for i in range(self.n_fragments):
178
+ outer_optimizers.append(
179
+ MockOptimizer(self.model.layers[i].parameters(), lr=self.outer_lr)
180
+ )
181
+ return outer_optimizers
182
+
183
+ def train_loop(self) -> Dict[str, Any]:
184
+ """Run the training loop with mocked components."""
185
+ # Ensure sync_every is set in diloco_args
186
+ if "sync_every" not in self.diloco_args:
187
+ self.diloco_args["sync_every"] = 2
188
+
189
+ parameter_history = {"history": {}, "global_parameter_history": {}}
190
+
191
+ with DiLoCo(
192
+ self.manager,
193
+ [layer for layer in self.model.layers],
194
+ self.inner_optimizer,
195
+ self.outer_optimizers,
196
+ backup_device=self.device,
197
+ **self.diloco_args,
198
+ ) as self.diloco:
199
+ if self.quorum_barrier is not None:
200
+ self.manager.start_quorum()
201
+ self.manager.wait_quorum()
202
+ assert self.quorum_barrier is not None
203
+ self.quorum_barrier.wait()
204
+ assert self.manager.should_commit()
205
+ assert self.manager.should_commit()
206
+
207
+ local_step = 0
208
+ manager_steps = set()
209
+ while True:
210
+ # Capture parameters before each step
211
+ step_params = {}
212
+ for name, param in self.model.named_parameters():
213
+ step_params[name] = param.data.clone().detach().cpu().tolist()
214
+ parameter_history["history"][local_step] = step_params
215
+
216
+ manager_curr_step = self.manager.current_step()
217
+
218
+ if manager_curr_step == 7:
219
+ break
220
+
221
+ if manager_curr_step not in manager_steps:
222
+ # Store the manager state dict, converting to the right type
223
+ state_dict = copy.deepcopy(self.manager._manager_state_dict())
224
+ user_state_dict = cast(dict[str, object], state_dict["user"])
225
+ parameter_history["global_parameter_history"][local_step] = {}
226
+
227
+ for i in range(self.n_fragments):
228
+ value = cast(
229
+ dict[str, dict[str, torch.Tensor]],
230
+ user_state_dict[f"StreamingDiLoCoFragment_{i}"],
231
+ )
232
+ parameter_history["global_parameter_history"][local_step][
233
+ f"layers.{i}.weight"
234
+ ] = (
235
+ value["original_parameters"]["weight"]
236
+ .data.clone()
237
+ .detach()
238
+ .cpu()
239
+ .tolist()
240
+ )
241
+
242
+ manager_steps.add(manager_curr_step)
243
+
244
+ # For each parameter, set a deterministic gradient
245
+ for _, layer in enumerate(self.model.layers):
246
+ if isinstance(layer, MockLinear):
247
+ # Set fixed gradients
248
+ layer.weight.grad = (
249
+ torch.ones_like(layer.weight) * layer.weight_grad_value
250
+ )
251
+
252
+ # Step with deterministic updates
253
+ self.inner_optimizer.step()
254
+
255
+ self.runner.event_injector.check(self.rank, self.manager.current_step())
256
+ local_step += 1
257
+
258
+ return parameter_history
259
+
260
+
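The dictionary returned by `train_loop` carries two views of the run: the local weights captured before every inner step, and a snapshot of each fragment's `original_parameters` taken once per new manager step. A rough sketch of its shape, with illustrative values for 1x1 layers:

```python
# Shape of the value returned by MockDiLoCoTrainer.train_loop (values illustrative).
parameter_history = {
    # Local weights captured before every inner step, keyed by local step.
    "history": {
        0: {"layers.0.weight": [[1.0]], "layers.1.weight": [[1.0]]},
        1: {"layers.0.weight": [[-1.0]], "layers.1.weight": [[-1.0]]},
        # ...
    },
    # Snapshot of each fragment's original_parameters, taken once per new
    # manager step, keyed by the local step at which it was observed.
    "global_parameter_history": {
        0: {"layers.0.weight": [[1.0]], "layers.1.weight": [[1.0]]},
        # ...
    },
}
```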
261
+ def mock_diloco_train_loop(
262
+ rank: int,
263
+ store_port: int,
264
+ device: torch.device,
265
+ runner: Runner,
266
+ train_loop_args: dict[str, Any] = {},
267
+ ) -> Dict[str, Dict[int, Dict[str, List[float]]]]:
268
+ """
269
+ Training loop with mocked components for deterministic parameter updates.
270
+ Uses MockDiLoCoTrainer to handle the training process.
271
+ """
272
+ model_state_dict = train_loop_args.get("model_state_dict", {})
273
+ n_fragments = train_loop_args.get("n_fragments", 1)
274
+ diloco_args = train_loop_args.get("diloco_args", {})
275
+ quorum_barrier = train_loop_args.get("quorum_barrier", None)
276
+
277
+ with ExitStack() as stack:
278
+ trainer = MockDiLoCoTrainer(
279
+ rank,
280
+ store_port,
281
+ device,
282
+ runner,
283
+ model_state_dict,
284
+ n_fragments,
285
+ diloco_args,
286
+ quorum_barrier=quorum_barrier,
287
+ )
288
+ stack.callback(trainer.manager.shutdown)
289
+ return trainer.train_loop()
290
+ return {}
291
+
292
+
293
+ class DiLoCoMockedUpdateTest(TestCase):
294
+ @parameterized.expand(
295
+ [
296
+ # Format: (use_cuda, n_fragments, fragment_sync_delay, fragment_update_alpha)
297
+ (False, 2, 0, 0), # 2 fragments, no delay, 0% mixing
298
+ (False, 2, 0, 0.5), # 2 fragments, no delay, 50% mixing
299
+ (False, 2, 0, 1), # 2 fragments, no delay, 100% mixing
300
+ (False, 2, 1, 0), # 2 fragments, with delay, 0% mixing
301
+ (False, 2, 1, 0.5), # 2 fragments, with delay, 50% mixing
302
+ (False, 2, 1, 1), # 2 fragments, with delay, 100% mixing
303
+ ]
304
+ )
305
+ def test_diloco_mocked_updates(
306
+ self,
307
+ use_cuda: bool,
308
+ n_fragments: int,
309
+ fragment_sync_delay: int,
310
+ fragment_update_alpha: float,
311
+ ) -> None:
312
+ """
313
+ Test that validates the model parameters are correctly updated by DiLoCo
314
+ using mocked components for deterministic updates with different configurations:
315
+ - n_fragments: Number of model fragments (1 or 2)
316
+ - fragment_sync_delay: Delay between preparing and syncing fragments (0 or 1)
317
+ - fragment_update_alpha: Controls mixing of local and global parameters (0.0, 0.5, or 1.0)
318
+ """
319
+ # Skip the test if use_cuda is True and there are not enough GPUs
320
+ if use_cuda and torch.cuda.device_count() < 2:
321
+ self.skipTest("Not enough GPUs for CUDA test")
322
+ if sys.platform == "darwin":
323
+ self.skipTest("not reliable on mac")
324
+
325
+ lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
326
+ sync_every = 6
327
+ num_replicas = 2
328
+ futures = []
329
+
330
+ torch.manual_seed(42)
331
+ # Initialize the model with the specified number of fragments
332
+ # Create a proper state_dict for the model to avoid load_state_dict errors
333
+ temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
334
+ model_state_dict = temp_model.state_dict()
335
+ quorum_barrier = threading.Barrier(num_replicas)
336
+
337
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
338
+ for replica_id in range(num_replicas):
339
+ event_injector = EventInjector()
340
+ runner = Runner(
341
+ replica_id=replica_id,
342
+ num_replicas=num_replicas,
343
+ lighthouse_address=lighthouse.address(),
344
+ event_injector=event_injector,
345
+ train_loop=mock_diloco_train_loop,
346
+ use_cuda=use_cuda,
347
+ train_loop_args={
348
+ "quorum_barrier": quorum_barrier,
349
+ "n_fragments": n_fragments,
350
+ "model_state_dict": model_state_dict,
351
+ "diloco_args": {
352
+ "sync_every": sync_every,
353
+ "fragment_sync_delay": fragment_sync_delay,
354
+ "fragment_update_alpha": fragment_update_alpha,
355
+ },
356
+ },
357
+ )
358
+ futures.append(executor.submit(runner.run_replica))
359
+
360
+ for fut in as_completed(futures):
361
+ continue
362
+
363
+ results = []
364
+ for fut in futures:
365
+ results.append(fut.result())
366
+
367
+ lighthouse.shutdown()
368
+
369
+ # Check results against fixture or validate parameter updates
370
+ compared_with_fixture = self._check_against_fixture(results)
371
+
372
+ if not compared_with_fixture:
373
+ # If no fixture comparison was done, validate parameters directly
374
+ self._validate_parameter_updates(
375
+ results[0][0],
376
+ n_fragments,
377
+ sync_every,
378
+ fragment_sync_delay,
379
+ fragment_update_alpha,
380
+ )
381
+
382
+ @parameterized.expand(
383
+ [
384
+ # Format: (use_cuda, n_fragments, fragment_sync_delay, fragment_update_alpha)
385
+ (False, 2, 0, 0), # 2 fragments, no delay, 0% mixing
386
+ ]
387
+ )
388
+ def test_diloco_mocked_failure_recovery(
389
+ self,
390
+ use_cuda: bool,
391
+ n_fragments: int,
392
+ fragment_sync_delay: int,
393
+ fragment_update_alpha: float,
394
+ ) -> None:
395
+ """
396
+ Test that validates DiLoCo can recover from a replica failure.
397
+ One replica is set to fail at step 2, and the test verifies that
398
+ the system recovers and parameters are correctly synchronized after recovery.
399
+ """
400
+ # Skip the test if use_cuda is True and there are not enough GPUs
401
+ if use_cuda and torch.cuda.device_count() < 2:
402
+ self.skipTest("Not enough GPUs for CUDA test")
403
+ if sys.platform == "darwin":
404
+ self.skipTest("not reliable on mac")
405
+
406
+ lighthouse = LighthouseServer(bind="[::]:0", min_replicas=2)
407
+ sync_every = 6
408
+ num_replicas = 2
409
+ futures = []
410
+
411
+ # Create event injectors - make the second replica fail at step 2
412
+ event_injectors = [
413
+ EventInjector(), # First replica runs normally
414
+ EventInjector().fail_at(0, 2), # Second replica fails at step 2
415
+ ]
416
+
417
+ torch.manual_seed(42)
418
+ # Initialize the model with the specified number of fragments
419
+ temp_model = MockModel(in_dim=1, out_dim=1, n_layers=n_fragments)
420
+ model_state_dict = temp_model.state_dict()
421
+
422
+ with ThreadPoolExecutor(max_workers=num_replicas) as executor:
423
+ for replica_id, event_injector in zip(range(num_replicas), event_injectors):
424
+ runner = Runner(
425
+ replica_id=replica_id,
426
+ num_replicas=num_replicas,
427
+ lighthouse_address=lighthouse.address(),
428
+ event_injector=event_injector,
429
+ train_loop=mock_diloco_train_loop,
430
+ use_cuda=use_cuda,
431
+ train_loop_args={
432
+ "n_fragments": n_fragments,
433
+ "model_state_dict": model_state_dict,
434
+ "diloco_args": {
435
+ "sync_every": sync_every,
436
+ "fragment_sync_delay": fragment_sync_delay,
437
+ "fragment_update_alpha": fragment_update_alpha,
438
+ },
439
+ },
440
+ )
441
+ futures.append(executor.submit(runner.run_replica))
442
+
443
+ # Wait for all futures to complete
444
+ for fut in as_completed(futures):
445
+ continue
446
+
447
+ results = []
448
+ for fut in futures:
449
+ try:
450
+ results.append(fut.result())
451
+ except Exception as e:
452
+ print(f"Error in replica: {e}")
453
+ raise
454
+
455
+ lighthouse.shutdown()
456
+
457
+ # Check results against fixture or validate failure recovery
458
+ compared_with_fixture = self._check_against_fixture(results)
459
+
460
+ if not compared_with_fixture:
461
+ # Verify that the failure was injected
462
+ self.assertEqual(
463
+ event_injectors[1].count[EventInjectorEvent.Failure],
464
+ 1,
465
+ "Expected one failure event to be injected",
466
+ )
467
+
468
+ # Verify that both replicas have the same global parameters at the end
469
+ # Extract the global parameter history from both replicas
470
+ rep0_global = results[0][0]["global_parameter_history"]
471
+ rep1_global = results[1][0]["global_parameter_history"]
472
+
473
+ # Get the last step in both histories
474
+ last_step_rep0 = max(int(step) for step in rep0_global.keys())
475
+ last_step_rep1 = max(int(step) for step in rep1_global.keys())
476
+
477
+ # Compare the global parameters at the last step
478
+ for param_name in rep0_global[last_step_rep0].keys():
479
+ rep0_param = torch.tensor(rep0_global[last_step_rep0][param_name])
480
+ rep1_param = torch.tensor(rep1_global[last_step_rep1][param_name])
481
+
482
+ self.assertTrue(
483
+ torch.allclose(rep0_param, rep1_param, rtol=1e-5, atol=1e-5),
484
+ f"Global parameters don't match at the end: {rep0_param} vs {rep1_param} for {param_name}",
485
+ )
486
+
487
+ def _check_against_fixture(
488
+ self, results: list[list[Dict[str, Dict[int, Dict[str, List[float]]]]]]
489
+ ) -> bool:
490
+ """
491
+ Check test results against fixture data.
492
+
493
+ Args:
494
+ results: The test results to check against fixture
495
+
496
+ Returns:
497
+ bool: True if comparison with fixture was performed, False otherwise
498
+ """
499
+ # Handle fixture reading/writing
500
+ fixture_data = handle_fixture(f"{self.id()}.json", results)
501
+
502
+ # If no fixture exists we can't compare results
503
+ if fixture_data is None:
504
+ return False
505
+
506
+ # Compare fixture data with current results
507
+ for replica_idx, (fixture_history, current_history) in enumerate(
508
+ zip(fixture_data, results)
509
+ ):
510
+ fixture_history = fixture_history[0]["history"]
511
+ current_history = current_history[0]["history"]
512
+ for step, fixture_params in fixture_history.items():
513
+ for param_name, fixture_values in fixture_params.items():
514
+ current_values = current_history[int(step)][param_name]
515
+ # Convert to tensors for comparison with tolerance
516
+ fixture_tensor = torch.tensor(fixture_values)
517
+ current_tensor = torch.tensor(current_values)
518
+ self.assertTrue(
519
+ torch.allclose(
520
+ fixture_tensor, current_tensor, rtol=1e-5, atol=1e-5
521
+ ),
522
+ f"{fixture_tensor} is not the same as {current_tensor} for {param_name} at step {step}",
523
+ )
524
+
525
+ return True
526
+
527
+ def _validate_parameter_updates(
528
+ self,
529
+ parameter_history: Dict[str, Dict[int, Dict[str, List[float]]]],
530
+ n_fragments: int,
531
+ sync_every: int,
532
+ fragment_sync_delay: int,
533
+ fragment_update_alpha: float,
534
+ ) -> None:
535
+ """
536
+ Validate that model parameters are updated as expected according to DiLoCo algorithm.
537
+ Validates both regular steps (inner optimizer updates) and sync steps (outer optimizer updates).
538
+
539
+ Args:
540
+ parameter_history: Parameter history for a replica
542
+ n_fragments: Number of model fragments
543
+ sync_every: How often to sync parameters
544
+ fragment_sync_delay: Delay between preparing and syncing fragments
545
+ fragment_update_alpha: Controls mixing of local and global parameters
546
+ """
547
+ # Sync happens every sync_every steps for each fragment
548
+ sync_every_fragment = sync_every // n_fragments
549
+
550
+ history = parameter_history["history"]
551
+ global_parameter_history = parameter_history["global_parameter_history"]
552
+
553
+ # For each step in history, validate parameter updates
554
+ for step in range(1, 16): # Skip step 0 (initial state)
555
+ for fragment_param_name in history[step].keys():
556
+ # Get current parameters
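+ # Parameter names have the form "layers.<i>.weight"; fragment_idx is the 1-based fragment index.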
557
+ fragment_idx = int(fragment_param_name.split(".")[1]) + 1
558
+ current_params = torch.tensor(history[step][fragment_param_name])
559
+
560
+ # Determine if this is a sync step for this fragment
561
+ # In DiLoCo, fragments are synced in a round-robin fashion
562
+ # Fragment i is synced at steps: i*sync_every_fragment + k*sync_every
563
+ # where k is a non-negative integer
564
+ is_sync_step = (
565
+ step - fragment_idx * sync_every_fragment
566
+ ) % sync_every == 0
567
+
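+ # Example: with sync_every=6 and n_fragments=2 (sync_every_fragment=3),
+ # layers.0 (fragment_idx=1) is expected to sync at steps 3, 9, 15 and
+ # layers.1 (fragment_idx=2) at steps 6, 12.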
568
+ if is_sync_step:
569
+ # This is a sync step for this fragment
570
+ # Find the previous sync step for this fragment
571
+ prev_sync_step = max(step - sync_every, 0)
572
+
573
+ # Find the prepare step for this fragment (when pseudogradients were calculated)
574
+ prepare_step = step - fragment_sync_delay
575
+
576
+ # Parameters at the previous sync step (global parameters before update)
577
+ prev_sync_params = torch.tensor(
578
+ global_parameter_history[prev_sync_step][fragment_param_name]
579
+ )
580
+
581
+ # Parameters at the prepare step (before allreduce)
582
+ prepare_params = (
583
+ torch.tensor(history[prepare_step - 1][fragment_param_name]) - 2
584
+ ) # inner_lr (1) * weight_grad_value (2)
585
+
586
+ # Calculate pseudogradient (difference between global and local params)
587
+ pseudogradient = prev_sync_params - prepare_params
588
+
589
+ # After allreduce, pseudogradient is averaged across replicas
590
+ # In our mock setup, all replicas have the same gradient, so no averaging is needed
591
+ averaged_pseudogradient = pseudogradient
592
+
593
+ # Outer optimizer applies this pseudogradient with its learning rate
594
+ outer_lr = 2
595
+
596
+ # Calculate expected global parameters after outer optimizer update
597
+ expected_global_params = (
598
+ prev_sync_params - outer_lr * averaged_pseudogradient
599
+ )
600
+
601
+ prev_params = torch.tensor(history[step - 1][fragment_param_name])
602
+ local_params = (
603
+ prev_params - 2
604
+ ) # inner_lr (1) * weight_grad_value (2)
605
+
606
+ # lerp: result = local_params * fragment_update_alpha + expected_global_params * (1 - fragment_update_alpha)
607
+ expected_params = (
608
+ local_params * fragment_update_alpha
609
+ + expected_global_params * (1 - fragment_update_alpha)
610
+ )
611
+
612
+ # Validate synced parameters
613
+ self.assertTrue(
614
+ torch.allclose(
615
+ current_params, expected_params, rtol=1e-5, atol=1e-5
616
+ ),
617
+ f"Parameters at sync step {step} for fragment {fragment_param_name} "
618
+ f"don't match expected: {current_params} vs {expected_params}. "
619
+ f"{prepare_params=}, {prev_sync_params=}, {pseudogradient=}, {averaged_pseudogradient=}, {expected_global_params=}",
620
+ )
621
+ else:
622
+ # Get previous parameters
623
+ prev_params = torch.tensor(history[step - 1][fragment_param_name])
624
+
625
+ # Regular step (inner optimizer update)
626
+ # In our mock setup, each step parameters change by -lr * grad = -1 * 2 = -2
627
+ expected_params = (
628
+ prev_params - 2
629
+ ) # inner_lr (1) * weight_grad_value (2)
630
+
631
+ # Validate parameters after a regular inner-optimizer step
632
+ self.assertTrue(
633
+ torch.allclose(
634
+ current_params, expected_params, rtol=1e-5, atol=1e-5
635
+ ),
636
+ f"Parameters at sync step {step} for fragment {fragment_param_name} "
637
+ f"don't match expected: {current_params} vs {expected_params}. ",
638
+ )
639
+
640
+
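To make the sync-step check above concrete, here is a worked instance of its arithmetic (a sketch, assuming the defaults exercised by `test_diloco_mocked_updates`: `inner_lr=1`, `weight_grad_value=2`, `outer_lr=2`, `sync_every=6`, `fragment_sync_delay=0`, `fragment_update_alpha=0`) for `layers.0` at step 3:

```python
# Worked instance of the sync-step validation for layers.0 at step 3.
prev_sync_params = 1.0                     # global snapshot at step 0
local_before_sync = 1.0 - 2 * 2.0          # history[2]: two inner steps of -2
prepare_params = local_before_sync - 2.0   # one more inner step -> -5.0

pseudogradient = prev_sync_params - prepare_params           # 1 - (-5) = 6.0
expected_global = prev_sync_params - 2.0 * pseudogradient    # outer_lr=2 -> -11.0
# With fragment_update_alpha = 0 (and fragment_sync_delay = 0, so the local
# parameters equal prepare_params), the lerp keeps only the global value:
expected_params = 0.0 * prepare_params + (1 - 0.0) * expected_global
print(expected_params)  # -11.0
```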
641
+ if __name__ == "__main__":
642
+ import unittest
643
+
644
+ unittest.main()
torchft/examples/slurm/README.md
@@ -0,0 +1,34 @@
1
+ ## Launch lighthouse
2
+
3
+ Run this command to launch the lighthouse on a node that the other Slurm nodes can reach:
4
+
5
+
6
+ ```bash
7
+ RUST_BACKTRACE=1 torchft_lighthouse --min_replicas 1 --quorum_tick_ms 100 --join_timeout_ms 10000
8
+ ```
9
+
10
+ ## Launch training
11
+
12
+ First, go to your local torchtitan folder and run:
13
+
14
+ ```bash
15
+ $ pip install -r requirements.txt
16
+ $ pip install .
17
+ ```
18
+
19
+ Run the following commands to launch the torchft lighthouse and replicas using torchtitan on Slurm:
20
+
21
+ ```bash
22
+ $ pip install torchx-nightly
23
+ $ # Set the address of the lighthouse server e.g.
24
+ $ export TORCHFT_LIGHTHOUSE=http://slurm-head-node-0:29510
25
+ $ python runner.py --workspace-dir=/path/to/torchtitan/folder --nodes=1 --nproc-per-node=8 --replica-count=2
26
+ ```
27
+
28
+ ## Test fault tolerance
29
+
30
+ To inject failures, use the following command:
31
+
32
+ ```bash
33
+ $ python punisher.py kill_loop --num-failures=10 --mtbf-secs=300
34
+ ```