nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78) hide show
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,844 @@
1
+ """
2
+ DiLoCo Trainer - Distributed Low-Communication Training
3
+
4
+ Implements the DiLoCo algorithm for bandwidth-efficient distributed training:
5
+ - Inner Loop: Each node trains independently for N steps (local SGD)
6
+ - Outer Loop: Periodically sync pseudo-gradients across peers
7
+ - Outer Optimizer: Nesterov momentum on the aggregated delta
8
+
9
+ Key Benefits:
10
+ - N× reduction in communication (sync every 500 steps vs every step)
11
+ - More robust to stragglers (nodes train at their own pace)
12
+ - Better for high-latency residential networks
13
+ - Naturally supports the "Don't Stop" soft overflow mechanism
14
+
15
+ Based on: "DiLoCo: Distributed Low-Communication Training of Language Models"
16
+ (Douillard et al., 2023)
17
+
18
+ Usage:
19
+ trainer = DiLoCoTrainer(model, DiLoCoConfig(inner_steps=500), inner_optimizer=optimizer)
20
+
21
+ # Training loop
22
+ while training:
23
+ loss = trainer.inner_step(compute_loss(model, batch))  # inner_step expects the loss tensor
24
+
25
+ if trainer.should_sync():
26
+ pseudo_grads = trainer.compute_pseudo_gradient()
27
+ aggregated = await gossip_gradients(pseudo_grads)
28
+ trainer.apply_outer_update(aggregated)
29
+ """
30
+
31
+ import asyncio
32
+ import copy
33
+ import logging
34
+ import threading
35
+ import time
36
+ from dataclasses import dataclass, field
37
+ from typing import Dict, List, Optional, Any, Callable, Tuple
38
+ from enum import Enum
39
+
40
+ import torch
41
+ import torch.nn as nn
42
+ import torch.nn.functional as F
43
+
44
+ logger = logging.getLogger(__name__)
45
+
46
+
47
class DiLoCoPhase(Enum):
    """Stage of the DiLoCo cycle that a trainer is currently in."""

    # Running local optimizer steps; no communication.
    INNER_LOOP = "inner_loop"
    # Deriving the pseudo-gradient from the weights.
    COMPUTING_DELTA = "computing_delta"
    # Waiting for the peer aggregation to come back.
    SYNCING = "syncing"
    # Applying the outer optimizer update.
    OUTER_STEP = "outer_step"
    # No inner loop in progress.
    IDLE = "idle"
54
+
55
+
56
@dataclass
class DiLoCoConfig:
    """Tunable settings for a DiLoCo training run."""

    # ---- Inner loop ----
    # Number of optimizer steps taken locally before a sync is due.
    inner_steps: int = 500
    # Learning rate of the inner optimizer.
    inner_lr: float = 1e-4
    # Weight decay of the inner optimizer.
    inner_weight_decay: float = 0.1

    # ---- Outer loop ----
    # Learning rate of the outer optimizer.
    outer_lr: float = 0.7
    # Nesterov momentum coefficient.
    outer_momentum: float = 0.9
    # Outer weight decay (usually 0).
    outer_weight_decay: float = 0.0

    # ---- Gradients ----
    # Max norm used for gradient clipping.
    max_grad_norm: float = 1.0
    # Number of backward passes accumulated per optimizer step.
    gradient_accumulation: int = 1

    # ---- Sync ----
    # Seconds to wait for peers before giving up.
    sync_timeout: float = 60.0
    # Minimum number of peers to average with.
    min_peers_for_sync: int = 1

    # ---- Validation ----
    # Whether gradient validation is enabled.
    validate_gradients: bool = True
    # Minimum cosine similarity.
    gradient_cosine_threshold: float = 0.5

    # ---- Learning rate schedule ----
    # Enable warmup followed by cosine annealing.
    use_lr_scheduler: bool = True
    # Length of the linear warmup ramp, in steps.
    warmup_steps: int = 1000
    # Floor of the schedule: min LR = min_lr_ratio * inner_lr.
    min_lr_ratio: float = 0.1
    # Step at which the cosine decay reaches the floor.
    lr_decay_steps: int = 50000
86
+
87
+
88
@dataclass
class DiLoCoStats:
    """Counters and timers collected during DiLoCo training."""

    # Step counters
    inner_step_count: int = 0
    outer_step_count: int = 0
    total_inner_steps: int = 0

    # Running loss totals for the current inner loop
    inner_loss_sum: float = 0.0
    inner_loss_count: int = 0

    # Sync outcomes
    successful_syncs: int = 0
    failed_syncs: int = 0
    local_only_outer_steps: int = 0

    # Timing
    inner_loop_time: float = 0.0
    outer_loop_time: float = 0.0
    sync_time: float = 0.0

    # Gradient stats
    avg_pseudo_grad_norm: float = 0.0
    avg_cosine_with_peers: float = 0.0

    @property
    def avg_inner_loss(self) -> float:
        """Mean of the recorded inner losses, or 0.0 when none were recorded."""
        recorded = self.inner_loss_count
        return self.inner_loss_sum / recorded if recorded != 0 else 0.0

    def reset_inner_stats(self):
        """Zero the per-inner-loop counters; lifetime totals are untouched."""
        per_loop_defaults = (
            ("inner_loss_sum", 0.0),
            ("inner_loss_count", 0),
            ("inner_step_count", 0),
        )
        for attr, zero in per_loop_defaults:
            setattr(self, attr, zero)
124
+
125
+
126
class OuterOptimizer:
    """
    Nesterov momentum optimizer for the DiLoCo outer loop.

    Pseudo-gradients are treated as a direction to move *towards*: the
    update is added to the weights, not subtracted:
        v_t = momentum * v_{t-1} + delta
        w_t = w_{t-1} + lr * (momentum * v_t + delta)

    When ``weight_decay`` is positive, ``weight_decay * w`` is subtracted
    from ``delta`` first, so the decay term shrinks the weights under this
    additive update.
    """

    def __init__(
        self,
        lr: float = 0.7,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
    ):
        """
        Args:
            lr: Outer learning rate
            momentum: Momentum coefficient
            weight_decay: Decay coefficient folded into the pseudo-gradient
        """
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # Momentum buffers, keyed by parameter name
        self.velocity: Dict[str, torch.Tensor] = {}

    def step(
        self,
        model: nn.Module,
        pseudo_gradients: Dict[str, torch.Tensor],
    ):
        """
        Apply one outer optimizer step in place.

        Parameters without an entry in ``pseudo_gradients`` are left
        untouched. The tensors in ``pseudo_gradients`` are not modified.

        Args:
            model: Model to update
            pseudo_gradients: Dict of name -> pseudo-gradient tensor
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in pseudo_gradients:
                    continue

                # Aggregated deltas may arrive on a different device than the
                # parameter; .to() is a no-op when they already match.
                delta = pseudo_gradients[name].to(param.device)

                # The update below is ADDED to the weights, so the decay term
                # must be subtracted to pull the weights toward zero.
                if self.weight_decay > 0:
                    delta = delta - self.weight_decay * param.data

                velocity = self.velocity.get(name)
                if velocity is None:
                    velocity = torch.zeros_like(delta)
                elif velocity.device != delta.device:
                    # e.g. state restored without a target device
                    velocity = velocity.to(delta.device)

                # v_new = momentum * v + delta
                velocity.mul_(self.momentum).add_(delta)
                self.velocity[name] = velocity

                # w = w + lr * (momentum * v_new + delta)
                param.data.add_(self.lr * (self.momentum * velocity + delta))

    def state_dict(self) -> Dict[str, Any]:
        """Return the hyper-parameters and a cloned copy of the momentum buffers."""
        return {
            'lr': self.lr,
            'momentum': self.momentum,
            'weight_decay': self.weight_decay,
            'velocity': {k: v.clone() for k, v in self.velocity.items()},
        }

    def load_state_dict(self, state: Dict[str, Any], device: Optional[str] = None):
        """Load optimizer state.

        Keys missing from ``state`` leave the current value unchanged
        (``velocity`` defaults to empty).

        Args:
            state: State dict to load
            device: Target device for tensors (if None, keeps original device)
        """
        self.lr = state.get('lr', self.lr)
        self.momentum = state.get('momentum', self.momentum)
        self.weight_decay = state.get('weight_decay', self.weight_decay)

        saved_velocity = state.get('velocity', {})
        if device:
            self.velocity = {k: v.clone().to(device) for k, v in saved_velocity.items()}
        else:
            self.velocity = {k: v.clone() for k, v in saved_velocity.items()}
212
+
213
+
214
class DiLoCoTrainer:
    """
    Distributed Low-Communication Training Coordinator.

    Drives one node's side of the DiLoCo cycle:
    1. Snapshot the weights at the start of an inner loop
    2. Train locally for ``config.inner_steps`` optimizer steps
    3. Compute the pseudo-gradient (current weights minus the snapshot)
    4. Hand it to the sync callback for peer aggregation
    5. Apply the outer optimizer update
    6. Repeat

    Integrates with:
    - SwarmRouter for peer discovery
    - ActivationBuffer for async compute
    - RobustAggregator for Byzantine-tolerant sync
    """

    def __init__(
        self,
        model: nn.Module,
        config: Optional[DiLoCoConfig] = None,
        inner_optimizer: Optional[torch.optim.Optimizer] = None,
        device: str = "cpu",
    ):
        """
        Initialize DiLoCo trainer.

        Args:
            model: Model to train
            config: DiLoCo configuration (defaults to ``DiLoCoConfig()``)
            inner_optimizer: Optimizer for inner loop (creates AdamW if None)
            device: Device for training (stored only; the model is not moved)
        """
        self.model = model
        self.config = config or DiLoCoConfig()
        self.device = device

        # Inner optimizer
        if inner_optimizer is None:
            self.inner_optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=self.config.inner_lr,
                weight_decay=self.config.inner_weight_decay,
                betas=(0.9, 0.95),
            )
        else:
            self.inner_optimizer = inner_optimizer

        # Outer optimizer
        self.outer_optimizer = OuterOptimizer(
            lr=self.config.outer_lr,
            momentum=self.config.outer_momentum,
            weight_decay=self.config.outer_weight_decay,
        )

        # Initial weights (saved at start of each inner loop)
        self.initial_weights: Dict[str, torch.Tensor] = {}

        # State
        self.phase = DiLoCoPhase.IDLE
        self.stats = DiLoCoStats()

        # Gradient accumulation
        self._accumulated_loss = 0.0
        self._accumulation_count = 0

        # Callbacks
        self._sync_callback: Optional[Callable] = None
        self._on_outer_step: Optional[Callable] = None

        # Learning rate scheduling
        self._base_lr = self.config.inner_lr
        self._current_lr = self._base_lr
        self._min_lr = self._base_lr * self.config.min_lr_ratio

        # Thread safety (re-entrant: locked methods call each other)
        self._lock = threading.RLock()

        logger.info(f"DiLoCoTrainer initialized: inner_steps={self.config.inner_steps}, "
                    f"outer_lr={self.config.outer_lr}, outer_momentum={self.config.outer_momentum}")
        if self.config.use_lr_scheduler:
            logger.info(f"  LR scheduler: warmup={self.config.warmup_steps}, "
                        f"decay_steps={self.config.lr_decay_steps}, min_ratio={self.config.min_lr_ratio}")

    def set_sync_callback(self, callback: Callable):
        """
        Set callback for pseudo-gradient synchronization.

        Callback signature:
            async def sync(pseudo_grads: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]

        Should return aggregated pseudo-gradients from peers.
        """
        self._sync_callback = callback

    def set_outer_step_callback(self, callback: Callable):
        """Set callback called with the stats object after each outer step."""
        self._on_outer_step = callback

    # ==================== LIFECYCLE ====================

    def start_inner_loop(self):
        """Start a new inner loop: snapshot weights and reset per-loop stats."""
        with self._lock:
            self._save_initial_weights()
            self.stats.reset_inner_stats()
            self.phase = DiLoCoPhase.INNER_LOOP

            logger.debug(f"Started inner loop {self.stats.outer_step_count + 1}")

    def _save_initial_weights(self):
        """Clone every trainable parameter as the baseline for the next delta."""
        self.initial_weights = {
            name: param.data.clone()
            for name, param in self.model.named_parameters()
            if param.requires_grad
        }

    # ==================== LEARNING RATE SCHEDULING ====================

    def _compute_lr(self, step: int) -> float:
        """
        Compute learning rate with warmup and cosine annealing.

        Schedule:
        1. Warmup phase (0 -> warmup_steps): Linear ramp from 0 to base_lr
        2. Decay phase (warmup_steps -> decay_steps): Cosine decay to min_lr
        3. After decay_steps: Hold at min_lr

        Args:
            step: Current total training step

        Returns:
            Learning rate for this step
        """
        import math

        warmup_steps = self.config.warmup_steps
        decay_steps = self.config.lr_decay_steps
        base_lr = self._base_lr
        min_lr = self._min_lr

        if step < warmup_steps:
            # Linear warmup (warmup_steps > step >= 0 here, so no zero division)
            return base_lr * (step / warmup_steps)
        if step < decay_steps:
            # Cosine annealing; decay_steps > step >= warmup_steps, so the
            # denominator is strictly positive
            progress = (step - warmup_steps) / (decay_steps - warmup_steps)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return min_lr + (base_lr - min_lr) * cosine_decay
        # After full decay, hold at min_lr
        return min_lr

    def _apply_lr(self, lr: float):
        """Write ``lr`` into every inner-optimizer param group and remember it."""
        for param_group in self.inner_optimizer.param_groups:
            param_group['lr'] = lr
        self._current_lr = lr

    def get_current_lr(self) -> float:
        """Get the learning rate most recently applied to the inner optimizer."""
        return self._current_lr

    # ==================== INNER LOOP ====================

    def inner_step(self, loss: torch.Tensor) -> float:
        """
        Execute one inner optimization step.

        Runs backward on ``loss`` and, once ``gradient_accumulation`` calls
        have been accumulated, applies LR scheduling (if enabled), clips
        gradients and steps the inner optimizer. No communication happens.

        Args:
            loss: Loss tensor from forward pass

        Returns:
            Loss value as float
        """
        with self._lock:
            if self.phase == DiLoCoPhase.IDLE:
                self.start_inner_loop()

            loss_value = loss.item()

            # Scale so the accumulated gradient is the mean over micro-batches
            scaled_loss = loss / self.config.gradient_accumulation
            scaled_loss.backward()

            self._accumulated_loss += loss_value
            self._accumulation_count += 1

            # Only step optimizer after accumulation
            if self._accumulation_count >= self.config.gradient_accumulation:
                # Apply learning rate scheduling before step
                if self.config.use_lr_scheduler:
                    new_lr = self._compute_lr(self.stats.total_inner_steps)
                    self._apply_lr(new_lr)

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )

                # Inner optimizer step
                self.inner_optimizer.step()
                self.inner_optimizer.zero_grad()

                # Update stats
                self.stats.inner_step_count += 1
                self.stats.total_inner_steps += 1
                self.stats.inner_loss_sum += self._accumulated_loss / self._accumulation_count
                self.stats.inner_loss_count += 1

                # Reset accumulation
                self._accumulated_loss = 0.0
                self._accumulation_count = 0

            return loss_value

    def should_sync(self) -> bool:
        """Check whether the current inner loop has reached ``inner_steps``."""
        return self.stats.inner_step_count >= self.config.inner_steps

    # ==================== OUTER LOOP ====================

    def compute_pseudo_gradient(self) -> Dict[str, torch.Tensor]:
        """
        Compute pseudo-gradient (delta from initial weights).

        Pseudo-gradient = current_weights - initial_weights, i.e. the
        direction the weights moved during the inner loop.

        NOTE: The sign here is CRITICAL for training to work!
        The outer optimizer ADDs this delta to the weights, so it must be
        ``current - initial`` and not the reverse.

        Returns:
            Dict of name -> pseudo-gradient tensor (only parameters that
            were snapshotted at the start of the inner loop)
        """
        with self._lock:
            self.phase = DiLoCoPhase.COMPUTING_DELTA

            pseudo_grads = {}
            total_norm = 0.0

            for name, param in self.model.named_parameters():
                if name in self.initial_weights:
                    # Delta = current - initial (direction we moved during training)
                    delta = param.data - self.initial_weights[name]
                    pseudo_grads[name] = delta
                    total_norm += delta.norm().item() ** 2

            # Global L2 norm of the latest pseudo-gradient (not a running average)
            self.stats.avg_pseudo_grad_norm = (total_norm ** 0.5)

            logger.info(f"Computed pseudo-gradient: "
                        f"norm={self.stats.avg_pseudo_grad_norm:.4f}, "
                        f"params={len(pseudo_grads)}")

            return pseudo_grads

    async def sync_with_peers(
        self,
        pseudo_grads: Dict[str, torch.Tensor],
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Synchronize pseudo-gradients with peers.

        Awaits the sync callback, bounded by ``config.sync_timeout``. The
        trainer lock is held only while mutating trainer state, never across
        the await, so other threads are not blocked for the duration of the
        network round trip.

        Args:
            pseudo_grads: Local pseudo-gradients

        Returns:
            Aggregated pseudo-gradients from the callback, the local
            pseudo-gradients when no callback is set, or None on
            timeout/failure
        """
        with self._lock:
            self.phase = DiLoCoPhase.SYNCING
            sync_callback = self._sync_callback

        if sync_callback is None:
            logger.warning("No sync callback set - using local gradients only")
            return pseudo_grads

        start_time = time.time()

        try:
            # Call sync callback (should gossip to peers)
            aggregated = await asyncio.wait_for(
                sync_callback(pseudo_grads),
                timeout=self.config.sync_timeout
            )
        except asyncio.TimeoutError:
            logger.warning(f"Sync timeout after {self.config.sync_timeout}s")
            with self._lock:
                self.stats.failed_syncs += 1
            return None
        except Exception as e:
            # Best effort: any callback failure falls back to a local update
            logger.error(f"Sync failed: {e}")
            with self._lock:
                self.stats.failed_syncs += 1
            return None

        sync_time = time.time() - start_time
        with self._lock:
            self.stats.sync_time += sync_time
            self.stats.successful_syncs += 1

        logger.info(f"Sync completed in {sync_time:.2f}s")

        return aggregated

    def apply_outer_update(
        self,
        aggregated_pseudo_grads: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """
        Apply outer optimizer step with aggregated pseudo-gradients.

        If no aggregated gradients are provided, the local pseudo-gradients
        are recomputed and used instead. Afterwards the outer-step callback
        (if any) is invoked and a new inner loop is started.

        Args:
            aggregated_pseudo_grads: Aggregated pseudo-gradients from peers
        """
        with self._lock:
            self.phase = DiLoCoPhase.OUTER_STEP
            start_time = time.time()

            # Use local gradients if sync failed
            if aggregated_pseudo_grads is None:
                logger.warning("Using local pseudo-gradients (sync failed)")
                aggregated_pseudo_grads = self.compute_pseudo_gradient()
                # compute_pseudo_gradient() switches the phase; restore it
                self.phase = DiLoCoPhase.OUTER_STEP
                self.stats.local_only_outer_steps += 1

            # NOTE(review): the update is applied on top of the locally
            # trained weights; the model is not reset to initial_weights
            # first. The DiLoCo paper applies the outer step to the
            # pre-inner-loop weights - confirm this deviation is intended.
            self.outer_optimizer.step(self.model, aggregated_pseudo_grads)

            # Update stats
            self.stats.outer_step_count += 1
            self.stats.outer_loop_time += time.time() - start_time

            logger.info(f"Outer step {self.stats.outer_step_count} complete "
                        f"(after {self.config.inner_steps} inner steps, "
                        f"avg_loss={self.stats.avg_inner_loss:.4f})")

            # Callback
            if self._on_outer_step:
                self._on_outer_step(self.stats)

            # Start new inner loop
            self.start_inner_loop()

    async def outer_step_async(self) -> bool:
        """
        Execute full outer step: compute, sync, apply.

        Returns:
            True if sync produced a result, False if the update fell back
            to local gradients
        """
        # Compute pseudo-gradient
        pseudo_grads = self.compute_pseudo_gradient()

        # Sync with peers
        aggregated = await self.sync_with_peers(pseudo_grads)

        # Apply update
        self.apply_outer_update(aggregated)

        return aggregated is not None

    def outer_step_sync(self) -> bool:
        """
        Synchronous version of outer step.

        Runs the async outer step to completion on a fresh event loop,
        which is closed afterwards.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.outer_step_async())
        finally:
            loop.close()

    # ==================== UTILITIES ====================

    def get_stats(self) -> Dict[str, Any]:
        """Get a snapshot of training statistics as a plain dict."""
        with self._lock:
            return {
                'phase': self.phase.value,
                'inner_step_count': self.stats.inner_step_count,
                'outer_step_count': self.stats.outer_step_count,
                'total_inner_steps': self.stats.total_inner_steps,
                'avg_inner_loss': self.stats.avg_inner_loss,
                'successful_syncs': self.stats.successful_syncs,
                'failed_syncs': self.stats.failed_syncs,
                'local_only_outer_steps': self.stats.local_only_outer_steps,
                'avg_pseudo_grad_norm': self.stats.avg_pseudo_grad_norm,
                'inner_loop_time': self.stats.inner_loop_time,
                'outer_loop_time': self.stats.outer_loop_time,
                'sync_time': self.stats.sync_time,
                # Learning rate info
                'current_lr': self._current_lr,
                'base_lr': self._base_lr,
                'min_lr': self._min_lr,
                'lr_scheduler_enabled': self.config.use_lr_scheduler,
            }

    def state_dict(self) -> Dict[str, Any]:
        """Get full trainer state for checkpointing."""
        with self._lock:
            return {
                # Copy so later config changes cannot alter the checkpoint
                'config': dict(self.config.__dict__),
                'inner_optimizer': self.inner_optimizer.state_dict(),
                'outer_optimizer': self.outer_optimizer.state_dict(),
                'initial_weights': {k: v.clone() for k, v in self.initial_weights.items()},
                'stats': {
                    'inner_step_count': self.stats.inner_step_count,
                    'outer_step_count': self.stats.outer_step_count,
                    'total_inner_steps': self.stats.total_inner_steps,
                },
                'phase': self.phase.value,
            }

    def load_state_dict(self, state: Dict[str, Any]):
        """
        Load trainer state from checkpoint.

        Unknown config keys are ignored; an unrecognised phase falls back
        to IDLE. Tensors are moved to the device of the model's first
        parameter (CPU if the model has none).
        """
        with self._lock:
            # Load config
            for k, v in state.get('config', {}).items():
                if hasattr(self.config, k):
                    setattr(self.config, k, v)

            # Keep the LR schedule consistent with the restored config
            self._base_lr = self.config.inner_lr
            self._min_lr = self._base_lr * self.config.min_lr_ratio

            # Target device, resolved once without materialising all parameters
            first_param = next(self.model.parameters(), None)
            device = first_param.device if first_param is not None else 'cpu'

            # Load optimizers (move tensors to model's device)
            if 'inner_optimizer' in state:
                self.inner_optimizer.load_state_dict(state['inner_optimizer'])
                # Report the restored LR rather than the pre-load one
                if self.inner_optimizer.param_groups:
                    self._current_lr = self.inner_optimizer.param_groups[0]['lr']
            if 'outer_optimizer' in state:
                self.outer_optimizer.load_state_dict(state['outer_optimizer'], device=str(device))

            # Load initial weights (move to model's device)
            self.initial_weights = {
                k: v.clone().to(device) for k, v in state.get('initial_weights', {}).items()
            }

            # Load stats
            stats = state.get('stats', {})
            self.stats.inner_step_count = stats.get('inner_step_count', 0)
            self.stats.outer_step_count = stats.get('outer_step_count', 0)
            self.stats.total_inner_steps = stats.get('total_inner_steps', 0)

            # Load phase
            phase_str = state.get('phase', 'idle')
            try:
                self.phase = DiLoCoPhase(phase_str)
            except ValueError:
                self.phase = DiLoCoPhase.IDLE
677
+
678
+
679
+ # ==================== GOSSIP INTEGRATION ====================
680
+
681
class DiLoCoGossipProtocol:
    """
    Gossip protocol for DiLoCo pseudo-gradient synchronization.

    Handles the communication aspect of DiLoCo:
    - Broadcast pseudo-gradients to peers
    - Collect and aggregate responses
    - Handle stragglers with timeout
    """

    def __init__(
        self,
        node_id: str,
        router: Any = None,  # SwarmRouter
        min_peers: int = 1,
        timeout: float = 30.0,
    ):
        """
        Args:
            node_id: Identifier under which the local contribution is stored
            router: Peer router; when None, syncing returns local gradients
            min_peers: Number of peer contributions to wait for
            timeout: Max seconds to wait for contributions
        """
        self.node_id = node_id
        self.router = router
        self.min_peers = min_peers
        self.timeout = timeout

        # Contributions received from peers and not yet consumed.
        # NOTE(review): entries are not keyed by round, so a late
        # contribution from an earlier round is consumed by the next one.
        self.pending_contributions: Dict[str, Dict[str, torch.Tensor]] = {}
        self._lock = threading.Lock()

    async def sync_pseudo_gradients(
        self,
        round_id: int,
        local_grads: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Synchronize pseudo-gradients with peers.

        Args:
            round_id: Outer step round ID
            local_grads: Local pseudo-gradients

        Returns:
            Averaged pseudo-gradients (``local_grads`` itself when no
            router is configured)
        """
        if self.router is None:
            logger.warning("No router - returning local gradients")
            return local_grads

        # Broadcast to peers
        await self._broadcast_grads(round_id, local_grads)

        # Wait for responses
        contributions = await self._collect_contributions(round_id)

        # Count peers BEFORE adding ourselves, otherwise the local node
        # satisfies min_peers on its own and the warning never fires.
        peer_count = sum(1 for peer_id in contributions if peer_id != self.node_id)
        if peer_count < self.min_peers:
            logger.warning(f"Only {peer_count} peers - below minimum {self.min_peers}")

        # Add our own contribution
        contributions[self.node_id] = local_grads

        return self._aggregate_contributions(contributions)

    async def _broadcast_grads(
        self,
        round_id: int,
        grads: Dict[str, torch.Tensor],
    ):
        """Broadcast pseudo-gradients to peers (currently a no-op placeholder)."""
        # Implementation would use gRPC to broadcast
        # Placeholder for integration with SwarmRouter
        pass

    async def _collect_contributions(
        self,
        round_id: int,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """
        Collect contributions from peers.

        Polls every 0.1s until ``min_peers`` contributions are pending or
        ``timeout`` seconds have elapsed, then returns and clears whatever
        is pending. ``round_id`` is currently unused.
        """
        # Monotonic clock: immune to wall-clock adjustments during the wait
        deadline = time.monotonic() + self.timeout

        while time.monotonic() < deadline:
            with self._lock:
                if len(self.pending_contributions) >= self.min_peers:
                    return self._drain_pending()

            await asyncio.sleep(0.1)

        # Timeout - return what we have
        with self._lock:
            return self._drain_pending()

    def _drain_pending(self) -> Dict[str, Dict[str, torch.Tensor]]:
        """Return a copy of the pending contributions and clear them (caller holds the lock)."""
        contributions = dict(self.pending_contributions)
        self.pending_contributions.clear()
        return contributions

    def _aggregate_contributions(
        self,
        contributions: Dict[str, Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        """
        Average contributions from all peers, parameter by parameter.

        A parameter is averaged over the contributors that supplied it.
        Output keys follow first-seen order, so the result is deterministic
        for a given input.
        """
        if not contributions:
            return {}

        # Group tensors by parameter name, preserving first-seen order
        grouped: Dict[str, List[torch.Tensor]] = {}
        for grads in contributions.values():
            for name, tensor in grads.items():
                grouped.setdefault(name, []).append(tensor)

        aggregated = {}
        for name, tensors in grouped.items():
            # torch.stack needs a single device; align to the first contributor
            target_device = tensors[0].device
            aligned = [t.to(target_device) for t in tensors]
            aggregated[name] = torch.stack(aligned).mean(dim=0)

        return aggregated

    def receive_contribution(
        self,
        peer_id: str,
        grads: Dict[str, torch.Tensor],
    ):
        """Store a peer's pseudo-gradients, replacing any earlier entry from that peer."""
        with self._lock:
            self.pending_contributions[peer_id] = grads
810
+
811
+
812
+ # ==================== FACTORY FUNCTIONS ====================
813
+
814
def create_diloco_trainer(
    model: nn.Module,
    inner_steps: int = 500,
    outer_lr: float = 0.7,
    inner_lr: float = 1e-4,
    device: str = "cpu",
    **config_kwargs,
) -> DiLoCoTrainer:
    """
    Build a DiLoCoTrainer from keyword settings.

    Args:
        model: Model to train
        inner_steps: Steps before each sync
        outer_lr: Outer optimizer learning rate
        inner_lr: Inner optimizer learning rate
        device: Training device
        **config_kwargs: Any further DiLoCoConfig fields

    Returns:
        Configured DiLoCoTrainer
    """
    # Merge the named settings with the pass-through config fields
    settings = dict(
        config_kwargs,
        inner_steps=inner_steps,
        outer_lr=outer_lr,
        inner_lr=inner_lr,
    )
    trainer_config = DiLoCoConfig(**settings)
    return DiLoCoTrainer(model, config=trainer_config, device=device)
844
+