nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/diloco.py
@@ -0,0 +1,844 @@
"""
DiLoCo Trainer - Distributed Low-Communication Training

Implements the DiLoCo algorithm for bandwidth-efficient distributed training:
- Inner Loop: Each node trains independently for N steps (local SGD)
- Outer Loop: Periodically sync pseudo-gradients across peers
- Outer Optimizer: Nesterov momentum on the aggregated delta

Key Benefits:
- N× reduction in communication (sync every 500 steps vs every step)
- More robust to stragglers (nodes train at their own pace)
- Better for high-latency residential networks
- Naturally supports the "Don't Stop" soft overflow mechanism

Based on: "DiLoCo: Distributed Low-Communication Training of Language Models"
(Douillard et al., 2023)

Usage:
    config = DiLoCoConfig(inner_steps=500)
    trainer = DiLoCoTrainer(model, config=config, inner_optimizer=optimizer)

    # Training loop
    while training:
        loss = compute_loss(model, batch)  # forward pass -> scalar loss tensor
        trainer.inner_step(loss)

        if trainer.should_sync():
            pseudo_grads = trainer.compute_pseudo_gradient()
            aggregated = await gossip_gradients(pseudo_grads)
            trainer.apply_outer_update(aggregated)
"""

import asyncio
import copy
import logging
import threading
import time
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Callable, Tuple
from enum import Enum

import torch
import torch.nn as nn
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class DiLoCoPhase(Enum):
    """Current phase of DiLoCo training."""
    INNER_LOOP = "inner_loop"            # Local training
    COMPUTING_DELTA = "computing_delta"  # Computing pseudo-gradient
    SYNCING = "syncing"                  # Waiting for peer aggregation
    OUTER_STEP = "outer_step"            # Applying outer update
    IDLE = "idle"


@dataclass
class DiLoCoConfig:
    """Configuration for DiLoCo training."""
    # Inner loop settings
    inner_steps: int = 500                  # Steps before sync
    inner_lr: float = 1e-4                  # Inner optimizer learning rate
    inner_weight_decay: float = 0.1         # Weight decay for inner optimizer

    # Outer loop settings
    outer_lr: float = 0.7                   # Outer optimizer learning rate
    outer_momentum: float = 0.9             # Nesterov momentum
    outer_weight_decay: float = 0.0         # Outer weight decay (usually 0)

    # Gradient settings
    max_grad_norm: float = 1.0              # Gradient clipping
    gradient_accumulation: int = 1          # Accumulation steps

    # Sync settings
    sync_timeout: float = 60.0              # Timeout waiting for peers
    min_peers_for_sync: int = 1             # Minimum peers to average with

    # Validation
    validate_gradients: bool = True         # Enable gradient validation
    gradient_cosine_threshold: float = 0.5  # Min cosine similarity

    # Learning rate scheduling (NEW)
    use_lr_scheduler: bool = True           # Enable cosine annealing LR
    warmup_steps: int = 1000                # LR warmup steps (linear ramp)
    min_lr_ratio: float = 0.1               # Min LR = min_lr_ratio * inner_lr
    lr_decay_steps: int = 50000             # Steps for full cosine cycle
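
# Illustrative sketch (not from the packaged module): the defaults above match
# the values cited in the module docstring (sync every 500 inner steps, Nesterov
# outer optimizer with lr 0.7 / momentum 0.9). A node on a slow residential link
# might trade sync frequency for staleness, for example:
#
#     config = DiLoCoConfig(
#         inner_steps=1000,         # half as many sync rounds
#         sync_timeout=120.0,       # tolerate slower peers
#         gradient_accumulation=4,  # simulate a larger batch on a small GPU
#     )
#     trainer = DiLoCoTrainer(model, config=config, device="cuda")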


@dataclass
class DiLoCoStats:
    """Statistics for DiLoCo training."""
    inner_step_count: int = 0
    outer_step_count: int = 0
    total_inner_steps: int = 0

    # Loss tracking
    inner_loss_sum: float = 0.0
    inner_loss_count: int = 0

    # Sync tracking
    successful_syncs: int = 0
    failed_syncs: int = 0
    local_only_outer_steps: int = 0

    # Timing
    inner_loop_time: float = 0.0
    outer_loop_time: float = 0.0
    sync_time: float = 0.0

    # Gradient stats
    avg_pseudo_grad_norm: float = 0.0
    avg_cosine_with_peers: float = 0.0

    @property
    def avg_inner_loss(self) -> float:
        if self.inner_loss_count == 0:
            return 0.0
        return self.inner_loss_sum / self.inner_loss_count

    def reset_inner_stats(self):
        """Reset stats for new inner loop."""
        self.inner_loss_sum = 0.0
        self.inner_loss_count = 0
        self.inner_step_count = 0


class OuterOptimizer:
    """
    Nesterov momentum optimizer for DiLoCo outer loop.

    Applies Nesterov-style momentum to pseudo-gradients:
        v_t = momentum * v_{t-1} + delta
        w_t = w_{t-1} + lr * (momentum * v_t + delta)

    This provides better convergence than simple averaging.
    """

    def __init__(
        self,
        lr: float = 0.7,
        momentum: float = 0.9,
        weight_decay: float = 0.0,
    ):
        self.lr = lr
        self.momentum = momentum
        self.weight_decay = weight_decay

        # Momentum buffers
        self.velocity: Dict[str, torch.Tensor] = {}

    def step(
        self,
        model: nn.Module,
        pseudo_gradients: Dict[str, torch.Tensor],
    ):
        """
        Apply outer optimizer step.

        Args:
            model: Model to update
            pseudo_gradients: Dict of name -> pseudo-gradient tensor
        """
        with torch.no_grad():
            for name, param in model.named_parameters():
                if name not in pseudo_gradients:
                    continue

                delta = pseudo_gradients[name]

                # Weight decay (applied to delta, not param)
                if self.weight_decay > 0:
                    delta = delta + self.weight_decay * param.data

                # Initialize velocity if needed
                if name not in self.velocity:
                    self.velocity[name] = torch.zeros_like(delta)

                v = self.velocity[name]

                # Nesterov momentum update
                # v_new = momentum * v + delta
                v.mul_(self.momentum).add_(delta)

                # Update: w = w + lr * (momentum * v_new + delta)
                # This is the "look ahead" part of Nesterov
                update = self.lr * (self.momentum * v + delta)
                param.data.add_(update)

                # Save velocity
                self.velocity[name] = v

    def state_dict(self) -> Dict[str, Any]:
        """Get optimizer state."""
        return {
            'lr': self.lr,
            'momentum': self.momentum,
            'velocity': {k: v.clone() for k, v in self.velocity.items()},
        }

    def load_state_dict(self, state: Dict[str, Any], device: str = None):
        """Load optimizer state.

        Args:
            state: State dict to load
            device: Target device for tensors (if None, keeps original device)
        """
        self.lr = state.get('lr', self.lr)
        self.momentum = state.get('momentum', self.momentum)
        if device:
            self.velocity = {k: v.clone().to(device) for k, v in state.get('velocity', {}).items()}
        else:
            self.velocity = {k: v.clone() for k, v in state.get('velocity', {}).items()}
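
# Illustrative worked example (not from the packaged module): one outer step on
# a fresh velocity buffer, with the defaults lr=0.7 and momentum=0.9 and a
# pseudo-gradient delta of 0.1 for a scalar weight w:
#
#     v_1    = 0.9 * 0.0 + 0.1         = 0.1
#     update = 0.7 * (0.9 * 0.1 + 0.1) = 0.133
#     w_1    = w_0 + 0.133
#
# The same step through the class, on a hypothetical toy module:
#
#     toy = nn.Linear(1, 1, bias=False)
#     outer = OuterOptimizer(lr=0.7, momentum=0.9)
#     deltas = {n: torch.full_like(p, 0.1) for n, p in toy.named_parameters()}
#     outer.step(toy, deltas)   # the weight moves by +0.133 on this first call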


class DiLoCoTrainer:
    """
    Distributed Low-Communication Training Coordinator.

    Manages the DiLoCo algorithm:
    1. Save initial weights at start of inner loop
    2. Train locally for N steps (inner loop)
    3. Compute pseudo-gradient (delta from initial)
    4. Sync with peers via gossip
    5. Apply outer optimizer update
    6. Repeat

    Integrates with:
    - SwarmRouter for peer discovery
    - ActivationBuffer for async compute
    - RobustAggregator for Byzantine-tolerant sync
    """

    def __init__(
        self,
        model: nn.Module,
        config: Optional[DiLoCoConfig] = None,
        inner_optimizer: Optional[torch.optim.Optimizer] = None,
        device: str = "cpu",
    ):
        """
        Initialize DiLoCo trainer.

        Args:
            model: Model to train
            config: DiLoCo configuration
            inner_optimizer: Optimizer for inner loop (creates AdamW if None)
            device: Device for training
        """
        self.model = model
        self.config = config or DiLoCoConfig()
        self.device = device

        # Inner optimizer
        if inner_optimizer is None:
            self.inner_optimizer = torch.optim.AdamW(
                model.parameters(),
                lr=self.config.inner_lr,
                weight_decay=self.config.inner_weight_decay,
                betas=(0.9, 0.95),
            )
        else:
            self.inner_optimizer = inner_optimizer

        # Outer optimizer
        self.outer_optimizer = OuterOptimizer(
            lr=self.config.outer_lr,
            momentum=self.config.outer_momentum,
            weight_decay=self.config.outer_weight_decay,
        )

        # Initial weights (saved at start of each inner loop)
        self.initial_weights: Dict[str, torch.Tensor] = {}

        # State
        self.phase = DiLoCoPhase.IDLE
        self.stats = DiLoCoStats()

        # Gradient accumulation
        self._accumulated_loss = 0.0
        self._accumulation_count = 0

        # Callbacks
        self._sync_callback: Optional[Callable] = None
        self._on_outer_step: Optional[Callable] = None

        # Learning rate scheduling
        self._base_lr = self.config.inner_lr
        self._current_lr = self._base_lr
        self._min_lr = self._base_lr * self.config.min_lr_ratio

        # Thread safety
        self._lock = threading.RLock()

        logger.info(f"DiLoCoTrainer initialized: inner_steps={self.config.inner_steps}, "
                    f"outer_lr={self.config.outer_lr}, outer_momentum={self.config.outer_momentum}")
        if self.config.use_lr_scheduler:
            logger.info(f"  LR scheduler: warmup={self.config.warmup_steps}, "
                        f"decay_steps={self.config.lr_decay_steps}, min_ratio={self.config.min_lr_ratio}")

    def set_sync_callback(self, callback: Callable):
        """
        Set callback for pseudo-gradient synchronization.

        Callback signature:
            async def sync(pseudo_grads: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]

        Should return aggregated pseudo-gradients from peers.
        """
        self._sync_callback = callback

    def set_outer_step_callback(self, callback: Callable):
        """Set callback called after each outer step."""
        self._on_outer_step = callback

    # ==================== LIFECYCLE ====================

    def start_inner_loop(self):
        """Start a new inner loop by saving initial weights."""
        with self._lock:
            self._save_initial_weights()
            self.stats.reset_inner_stats()
            self.phase = DiLoCoPhase.INNER_LOOP

        logger.debug(f"Started inner loop {self.stats.outer_step_count + 1}")

    def _save_initial_weights(self):
        """Save current weights as initial for pseudo-gradient computation."""
        self.initial_weights = {}
        for name, param in self.model.named_parameters():
            if param.requires_grad:
                self.initial_weights[name] = param.data.clone()

    # ==================== LEARNING RATE SCHEDULING ====================

    def _compute_lr(self, step: int) -> float:
        """
        Compute learning rate with warmup and cosine annealing.

        Schedule:
        1. Warmup phase (0 -> warmup_steps): Linear ramp from 0 to base_lr
        2. Decay phase (warmup_steps -> decay_steps): Cosine decay to min_lr
        3. After decay_steps: Hold at min_lr

        Args:
            step: Current total training step

        Returns:
            Learning rate for this step
        """
        import math

        warmup_steps = self.config.warmup_steps
        decay_steps = self.config.lr_decay_steps
        base_lr = self._base_lr
        min_lr = self._min_lr

        if step < warmup_steps:
            # Linear warmup
            return base_lr * (step / warmup_steps)
        elif step < decay_steps:
            # Cosine annealing decay
            progress = (step - warmup_steps) / (decay_steps - warmup_steps)
            cosine_decay = 0.5 * (1 + math.cos(math.pi * progress))
            return min_lr + (base_lr - min_lr) * cosine_decay
        else:
            # After full decay, hold at min_lr
            return min_lr
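
    # Illustrative values (not from the packaged module): with the defaults
    # (inner_lr=1e-4, warmup_steps=1000, lr_decay_steps=50000, min_lr_ratio=0.1,
    # so min_lr=1e-5), _compute_lr() yields approximately:
    #
    #     step      0  ->  0.0      (start of linear warmup)
    #     step    500  ->  5.0e-5   (halfway through warmup)
    #     step  1,000  ->  1.0e-4   (warmup done; cosine decay starts at base_lr)
    #     step 25,500  ->  5.5e-5   (cosine midpoint: min_lr + 0.5 * (base_lr - min_lr))
    #     step 50,000+ ->  1.0e-5   (held at min_lr)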

    def _apply_lr(self, lr: float):
        """Apply learning rate to the inner optimizer."""
        for param_group in self.inner_optimizer.param_groups:
            param_group['lr'] = lr
        self._current_lr = lr

    def get_current_lr(self) -> float:
        """Get the current learning rate."""
        return self._current_lr

    # ==================== INNER LOOP ====================

    def inner_step(self, loss: torch.Tensor) -> float:
        """
        Execute one inner optimization step.

        This is normal SGD/AdamW training - no communication needed.
        Applies learning rate scheduling if enabled.

        Args:
            loss: Loss tensor from forward pass

        Returns:
            Loss value as float
        """
        with self._lock:
            if self.phase == DiLoCoPhase.IDLE:
                self.start_inner_loop()

            loss_value = loss.item()

            # Gradient accumulation
            scaled_loss = loss / self.config.gradient_accumulation
            scaled_loss.backward()

            self._accumulated_loss += loss_value
            self._accumulation_count += 1

            # Only step optimizer after accumulation
            if self._accumulation_count >= self.config.gradient_accumulation:
                # Apply learning rate scheduling before step
                if self.config.use_lr_scheduler:
                    new_lr = self._compute_lr(self.stats.total_inner_steps)
                    self._apply_lr(new_lr)

                # Gradient clipping
                torch.nn.utils.clip_grad_norm_(
                    self.model.parameters(),
                    self.config.max_grad_norm
                )

                # Inner optimizer step
                self.inner_optimizer.step()
                self.inner_optimizer.zero_grad()

                # Update stats
                self.stats.inner_step_count += 1
                self.stats.total_inner_steps += 1
                self.stats.inner_loss_sum += self._accumulated_loss / self._accumulation_count
                self.stats.inner_loss_count += 1

                # Reset accumulation
                self._accumulated_loss = 0.0
                self._accumulation_count = 0

            return loss_value

    def should_sync(self) -> bool:
        """Check if we should trigger outer sync."""
        return self.stats.inner_step_count >= self.config.inner_steps

    # ==================== OUTER LOOP ====================

    def compute_pseudo_gradient(self) -> Dict[str, torch.Tensor]:
        """
        Compute pseudo-gradient (delta from initial weights).

        Pseudo-gradient = current_weights - initial_weights
        This represents the DIRECTION we improved during the inner loop.

        The outer optimizer will then AMPLIFY this direction with momentum,
        effectively saying "we moved this way and it reduced loss, so let's
        continue moving this way".

        NOTE: The sign here is CRITICAL for training to work!
        - current - initial = direction we moved (positive = training progress)
        - The outer optimizer ADDs this to weights, amplifying the improvement

        Returns:
            Dict of name -> pseudo-gradient tensor
        """
        with self._lock:
            self.phase = DiLoCoPhase.COMPUTING_DELTA

            pseudo_grads = {}
            total_norm = 0.0

            for name, param in self.model.named_parameters():
                if name in self.initial_weights:
                    # Delta = current - initial (direction we moved during training)
                    # This is POSITIVE when training made progress
                    delta = param.data - self.initial_weights[name]
                    pseudo_grads[name] = delta
                    total_norm += delta.norm().item() ** 2

            # Update stats
            self.stats.avg_pseudo_grad_norm = (total_norm ** 0.5)

            logger.info(f"Computed pseudo-gradient: "
                        f"norm={self.stats.avg_pseudo_grad_norm:.4f}, "
                        f"params={len(pseudo_grads)}")

            return pseudo_grads
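
    # Illustrative arithmetic (not from the packaged module): sign check on one
    # scalar weight. If the inner loop moved it from 1.00 down to 0.97, its
    # pseudo-gradient is 0.97 - 1.00 = -0.03. Because OuterOptimizer.step() ADDs
    # lr * (momentum * v + delta) to the weight, the averaged (negative) delta
    # keeps pushing the weight in the direction the inner loop already found
    # useful, rather than undoing it.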

    async def sync_with_peers(
        self,
        pseudo_grads: Dict[str, torch.Tensor],
    ) -> Optional[Dict[str, torch.Tensor]]:
        """
        Synchronize pseudo-gradients with peers.

        Uses the sync callback to gossip and aggregate.

        Args:
            pseudo_grads: Local pseudo-gradients

        Returns:
            Aggregated pseudo-gradients from all peers, or None on failure
        """
        with self._lock:
            self.phase = DiLoCoPhase.SYNCING

        if self._sync_callback is None:
            logger.warning("No sync callback set - using local gradients only")
            return pseudo_grads

        start_time = time.time()

        try:
            # Call sync callback (should gossip to peers)
            aggregated = await asyncio.wait_for(
                self._sync_callback(pseudo_grads),
                timeout=self.config.sync_timeout
            )

            sync_time = time.time() - start_time
            self.stats.sync_time += sync_time
            self.stats.successful_syncs += 1

            logger.info(f"Sync completed in {sync_time:.2f}s")

            return aggregated

        except asyncio.TimeoutError:
            logger.warning(f"Sync timeout after {self.config.sync_timeout}s")
            self.stats.failed_syncs += 1
            return None

        except Exception as e:
            logger.error(f"Sync failed: {e}")
            self.stats.failed_syncs += 1
            return None

    def apply_outer_update(
        self,
        aggregated_pseudo_grads: Optional[Dict[str, torch.Tensor]] = None,
    ):
        """
        Apply outer optimizer step with aggregated pseudo-gradients.

        If no aggregated gradients provided, uses local pseudo-gradients.

        Args:
            aggregated_pseudo_grads: Aggregated pseudo-gradients from peers
        """
        with self._lock:
            self.phase = DiLoCoPhase.OUTER_STEP
            start_time = time.time()

            # Use local gradients if sync failed
            if aggregated_pseudo_grads is None:
                logger.warning("Using local pseudo-gradients (sync failed)")
                aggregated_pseudo_grads = self.compute_pseudo_gradient()
                self.stats.local_only_outer_steps += 1

            # Apply outer optimizer
            self.outer_optimizer.step(self.model, aggregated_pseudo_grads)

            # Update stats
            self.stats.outer_step_count += 1
            self.stats.outer_loop_time += time.time() - start_time

            logger.info(f"Outer step {self.stats.outer_step_count} complete "
                        f"(after {self.config.inner_steps} inner steps, "
                        f"avg_loss={self.stats.avg_inner_loss:.4f})")

            # Callback
            if self._on_outer_step:
                self._on_outer_step(self.stats)

            # Start new inner loop
            self.start_inner_loop()

    async def outer_step_async(self) -> bool:
        """
        Execute full outer step: compute, sync, apply.

        Async version that handles the full sync flow.

        Returns:
            True if sync succeeded, False if used local gradients
        """
        # Compute pseudo-gradient
        pseudo_grads = self.compute_pseudo_gradient()

        # Sync with peers
        aggregated = await self.sync_with_peers(pseudo_grads)

        # Apply update
        self.apply_outer_update(aggregated)

        return aggregated is not None

    def outer_step_sync(self) -> bool:
        """
        Synchronous version of outer step.

        Runs async outer step in new event loop.
        """
        loop = asyncio.new_event_loop()
        try:
            return loop.run_until_complete(self.outer_step_async())
        finally:
            loop.close()

    # ==================== UTILITIES ====================

    def get_stats(self) -> Dict[str, Any]:
        """Get training statistics."""
        with self._lock:
            return {
                'phase': self.phase.value,
                'inner_step_count': self.stats.inner_step_count,
                'outer_step_count': self.stats.outer_step_count,
                'total_inner_steps': self.stats.total_inner_steps,
                'avg_inner_loss': self.stats.avg_inner_loss,
                'successful_syncs': self.stats.successful_syncs,
                'failed_syncs': self.stats.failed_syncs,
                'local_only_outer_steps': self.stats.local_only_outer_steps,
                'avg_pseudo_grad_norm': self.stats.avg_pseudo_grad_norm,
                'inner_loop_time': self.stats.inner_loop_time,
                'outer_loop_time': self.stats.outer_loop_time,
                'sync_time': self.stats.sync_time,
                # Learning rate info
                'current_lr': self._current_lr,
                'base_lr': self._base_lr,
                'min_lr': self._min_lr,
                'lr_scheduler_enabled': self.config.use_lr_scheduler,
            }

    def state_dict(self) -> Dict[str, Any]:
        """Get full trainer state for checkpointing."""
        with self._lock:
            return {
                'config': self.config.__dict__,
                'inner_optimizer': self.inner_optimizer.state_dict(),
                'outer_optimizer': self.outer_optimizer.state_dict(),
                'initial_weights': {k: v.clone() for k, v in self.initial_weights.items()},
                'stats': {
                    'inner_step_count': self.stats.inner_step_count,
                    'outer_step_count': self.stats.outer_step_count,
                    'total_inner_steps': self.stats.total_inner_steps,
                },
                'phase': self.phase.value,
            }

    def load_state_dict(self, state: Dict[str, Any]):
        """Load trainer state from checkpoint."""
        with self._lock:
            # Load config
            for k, v in state.get('config', {}).items():
                if hasattr(self.config, k):
                    setattr(self.config, k, v)

            # Load optimizers (move tensors to model's device)
            device = next(self.model.parameters()).device if list(self.model.parameters()) else 'cpu'
            if 'inner_optimizer' in state:
                self.inner_optimizer.load_state_dict(state['inner_optimizer'])
            if 'outer_optimizer' in state:
                self.outer_optimizer.load_state_dict(state['outer_optimizer'], device=str(device))

            # Load initial weights (move to model's device)
            device = next(self.model.parameters()).device if list(self.model.parameters()) else 'cpu'
            self.initial_weights = {
                k: v.clone().to(device) for k, v in state.get('initial_weights', {}).items()
            }

            # Load stats
            stats = state.get('stats', {})
            self.stats.inner_step_count = stats.get('inner_step_count', 0)
            self.stats.outer_step_count = stats.get('outer_step_count', 0)
            self.stats.total_inner_steps = stats.get('total_inner_steps', 0)

            # Load phase
            phase_str = state.get('phase', 'idle')
            try:
                self.phase = DiLoCoPhase(phase_str)
            except ValueError:
                self.phase = DiLoCoPhase.IDLE
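
# Illustrative sketch (not from the packaged module): the trainer is
# transport-agnostic and only needs an async callback with the signature
# documented in set_sync_callback(). A minimal callback that averages the local
# delta with one peer's (fetched by a hypothetical transport helper):
#
#     async def sync(pseudo_grads):
#         peer = await fetch_peer_delta()   # hypothetical network call
#         return {
#             name: (local + peer[name]) / 2.0 if name in peer else local
#             for name, local in pseudo_grads.items()
#         }
#
#     trainer.set_sync_callback(sync)
#     ok = await trainer.outer_step_async()   # compute -> sync -> apply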


# ==================== GOSSIP INTEGRATION ====================

class DiLoCoGossipProtocol:
    """
    Gossip protocol for DiLoCo pseudo-gradient synchronization.

    Handles the communication aspect of DiLoCo:
    - Broadcast pseudo-gradients to peers
    - Collect and aggregate responses
    - Handle stragglers with timeout
    """

    def __init__(
        self,
        node_id: str,
        router: Any = None,  # SwarmRouter
        min_peers: int = 1,
        timeout: float = 30.0,
    ):
        self.node_id = node_id
        self.router = router
        self.min_peers = min_peers
        self.timeout = timeout

        # Pending contributions
        self.pending_contributions: Dict[str, Dict[str, torch.Tensor]] = {}
        self._lock = threading.Lock()

    async def sync_pseudo_gradients(
        self,
        round_id: int,
        local_grads: Dict[str, torch.Tensor],
    ) -> Dict[str, torch.Tensor]:
        """
        Synchronize pseudo-gradients with peers.

        Args:
            round_id: Outer step round ID
            local_grads: Local pseudo-gradients

        Returns:
            Averaged pseudo-gradients
        """
        # Get peers
        if self.router is None:
            logger.warning("No router - returning local gradients")
            return local_grads

        # Broadcast to peers
        await self._broadcast_grads(round_id, local_grads)

        # Wait for responses
        contributions = await self._collect_contributions(round_id)

        # Add our own contribution
        contributions[self.node_id] = local_grads

        # Aggregate
        if len(contributions) < self.min_peers:
            logger.warning(f"Only {len(contributions)} peers - below minimum {self.min_peers}")

        aggregated = self._aggregate_contributions(contributions)

        return aggregated

    async def _broadcast_grads(
        self,
        round_id: int,
        grads: Dict[str, torch.Tensor],
    ):
        """Broadcast pseudo-gradients to peers."""
        # Implementation would use gRPC to broadcast
        # Placeholder for integration with SwarmRouter
        pass

    async def _collect_contributions(
        self,
        round_id: int,
    ) -> Dict[str, Dict[str, torch.Tensor]]:
        """Collect contributions from peers."""
        # Wait up to timeout for contributions
        start = time.time()

        while time.time() - start < self.timeout:
            with self._lock:
                if len(self.pending_contributions) >= self.min_peers:
                    contributions = dict(self.pending_contributions)
                    self.pending_contributions.clear()
                    return contributions

            await asyncio.sleep(0.1)

        # Timeout - return what we have
        with self._lock:
            contributions = dict(self.pending_contributions)
            self.pending_contributions.clear()
            return contributions

    def _aggregate_contributions(
        self,
        contributions: Dict[str, Dict[str, torch.Tensor]],
    ) -> Dict[str, torch.Tensor]:
        """Average contributions from all peers."""
        if not contributions:
            return {}

        # Get all param names
        param_names = set()
        for grads in contributions.values():
            param_names.update(grads.keys())

        # Average each parameter
        aggregated = {}
        for name in param_names:
            tensors = [
                grads[name] for grads in contributions.values()
                if name in grads
            ]
            if tensors:
                aggregated[name] = torch.stack(tensors).mean(dim=0)

        return aggregated

    def receive_contribution(
        self,
        peer_id: str,
        grads: Dict[str, torch.Tensor],
    ):
        """Receive pseudo-gradient contribution from peer."""
        with self._lock:
            self.pending_contributions[peer_id] = grads
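
# Illustrative sketch (not from the packaged module): DiLoCoGossipProtocol can
# back the trainer's sync callback once the node's network layer actually
# delivers peer deltas to receive_contribution(); _broadcast_grads() above is
# still a placeholder. Assuming such wiring exists and `swarm_router` is the
# node's router object:
#
#     gossip = DiLoCoGossipProtocol(node_id="node-a", router=swarm_router, min_peers=2)
#
#     async def sync(pseudo_grads):
#         return await gossip.sync_pseudo_gradients(
#             round_id=trainer.stats.outer_step_count,
#             local_grads=pseudo_grads,
#         )
#
#     trainer.set_sync_callback(sync)
#     # Network handlers would push peer deltas in via:
#     #     gossip.receive_contribution(peer_id, grads)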


# ==================== FACTORY FUNCTIONS ====================

def create_diloco_trainer(
    model: nn.Module,
    inner_steps: int = 500,
    outer_lr: float = 0.7,
    inner_lr: float = 1e-4,
    device: str = "cpu",
    **config_kwargs,
) -> DiLoCoTrainer:
    """
    Factory function to create a DiLoCo trainer.

    Args:
        model: Model to train
        inner_steps: Steps before each sync
        outer_lr: Outer optimizer learning rate
        inner_lr: Inner optimizer learning rate
        device: Training device
        **config_kwargs: Additional config options

    Returns:
        Configured DiLoCoTrainer
    """
    config = DiLoCoConfig(
        inner_steps=inner_steps,
        outer_lr=outer_lr,
        inner_lr=inner_lr,
        **config_kwargs,
    )

    return DiLoCoTrainer(model, config=config, device=device)
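
# Illustrative end-to-end sketch (not from the packaged module), assuming the
# caller provides `model`, `data_loader`, and a `compute_loss` helper:
#
#     trainer = create_diloco_trainer(model, inner_steps=500, inner_lr=1e-4)
#
#     async def train():
#         for batch in data_loader:
#             loss = compute_loss(model, batch)       # forward pass -> scalar loss
#             trainer.inner_step(loss)                # backward + clipped AdamW step
#             if trainer.should_sync():
#                 await trainer.outer_step_async()    # falls back to local delta on sync failure
#         print(trainer.get_stats())
#
#     asyncio.run(train())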