nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/training/global_tracker.py
@@ -0,0 +1,617 @@
"""
Global Training Tracker - Verify Distributed Training is Working

This module provides network-wide training verification:
1. Tracks loss across ALL nodes in the network
2. Monitors model hash convergence (ensures nodes sync)
3. Computes global training metrics (not just local batch loss)
4. Provides dashboards/APIs for monitoring

Key Concepts:
- Moving Average Loss: Smoothed loss over time (local)
- Global Loss: Average loss across all network nodes
- Model Hash: SHA256 of model weights (should converge across nodes)
- Sync Rate: How often nodes successfully sync gradients
"""

import time
import hashlib
import threading
import logging
import json
from collections import deque
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Any, Tuple
from pathlib import Path

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


@dataclass
class TrainingSnapshot:
    """A snapshot of training state at a point in time."""
    timestamp: float
    node_id: str

    # Loss metrics
    batch_loss: float        # Raw batch loss
    moving_avg_loss: float   # Smoothed loss (EMA)
    min_loss_seen: float     # Best loss achieved

    # Training progress
    training_step: int       # Global step count
    inner_step: int          # DiLoCo inner step (0-500)
    outer_step: int          # DiLoCo outer step (sync count)

    # Convergence metrics
    model_hash: str          # Hash of model weights
    gradient_norm: float     # L2 norm of gradients

    # Data coverage
    shard_id: int            # Current data shard
    tokens_trained: int      # Total tokens seen


@dataclass
class GlobalTrainingStats:
    """Network-wide training statistics."""
    # Aggregated metrics
    global_avg_loss: float = 0.0
    global_min_loss: float = float('inf')

    # Convergence tracking
    model_hashes: Dict[str, str] = field(default_factory=dict)  # node_id -> hash
    hash_agreement_rate: float = 0.0  # % of nodes with same hash

    # Network health
    total_nodes_training: int = 0
    successful_syncs: int = 0
    failed_syncs: int = 0

    # Progress
    global_steps: int = 0
    global_tokens: int = 0
    data_shards_covered: set = field(default_factory=set)

    # Time tracking
    last_update: float = 0.0


class GlobalTrainingTracker:
    """
    Tracks and verifies distributed training across the network.

    Each node runs this to:
    1. Track its own training progress
    2. Receive training stats from peers via gossip
    3. Compute global metrics
    4. Verify convergence (all nodes should have similar model hash)

    Usage:
        tracker = GlobalTrainingTracker(node_id, model)

        # During training
        tracker.record_step(loss, step, shard_id)

        # Get status
        status = tracker.get_global_status()
        print(f"Global loss: {status['global_avg_loss']:.4f}")
        print(f"Hash agreement: {status['hash_agreement_rate']*100:.1f}%")
    """

    # EMA smoothing factor (lower = smoother)
    EMA_ALPHA = 0.1

    # History window for computing metrics
    HISTORY_WINDOW = 100

    # Minimum nodes for global metrics
    MIN_NODES_FOR_GLOBAL = 1

    def __init__(
        self,
        node_id: str,
        model: nn.Module,
        checkpoint_dir: Optional[Path] = None,
    ):
        self.node_id = node_id
        self.model = model
        self.checkpoint_dir = checkpoint_dir or Path.home() / ".neuroshard" / "training_logs"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Local tracking
        self._local_history: deque = deque(maxlen=self.HISTORY_WINDOW)
        self._moving_avg_loss = 0.0
        self._min_loss = float('inf')
        self._total_steps = 0
        self._total_tokens = 0
        self._current_shard = 0

        # Peer tracking (received via gossip)
        self._peer_stats: Dict[str, TrainingSnapshot] = {}
        self._peer_stats_lock = threading.Lock()

        # Global aggregated stats
        self._global_stats = GlobalTrainingStats()

        # Sync tracking
        self._sync_history: deque = deque(maxlen=50)  # (timestamp, success)
        self._last_model_hash = ""

        # Verification
        self._loss_checkpoints: List[Tuple[int, float]] = []  # (step, loss)

        # Try to restore previous state
        self._load_state()

        logger.info(f"GlobalTrainingTracker initialized for node {node_id[:8]}...")

    def _get_state_path(self) -> Path:
        """Get path to state file."""
        return self.checkpoint_dir / f"tracker_state_{self.node_id[:16]}.json"

    def _save_state(self):
        """Persist tracker state to disk."""
        try:
            state = {
                "node_id": self.node_id,
                "saved_at": time.time(),
                "moving_avg_loss": self._moving_avg_loss,
                "min_loss": self._min_loss if self._min_loss != float('inf') else None,
                "total_steps": self._total_steps,
                "total_tokens": self._total_tokens,
                "current_shard": self._current_shard,
                "last_model_hash": self._last_model_hash,
                "loss_checkpoints": self._loss_checkpoints[-100:],  # Keep last 100
                "global_stats": {
                    "global_avg_loss": self._global_stats.global_avg_loss,
                    "global_min_loss": self._global_stats.global_min_loss if self._global_stats.global_min_loss != float('inf') else None,
                    "total_nodes_training": self._global_stats.total_nodes_training,
                    "global_steps": self._global_stats.global_steps,
                    "global_tokens": self._global_stats.global_tokens,
                    "successful_syncs": self._global_stats.successful_syncs,
                    "failed_syncs": self._global_stats.failed_syncs,
                },
            }

            with open(self._get_state_path(), 'w') as f:
                json.dump(state, f, indent=2, default=str)

            logger.debug(f"Tracker state saved: {self._total_steps} steps, loss={self._moving_avg_loss:.4f}")
        except Exception as e:
            logger.warning(f"Failed to save tracker state: {e}")

    def _load_state(self):
        """Load tracker state from disk."""
        path = self._get_state_path()
        if not path.exists():
            logger.debug("No previous tracker state found")
            return

        try:
            with open(path, 'r') as f:
                state = json.load(f)

            # Restore local state
            self._moving_avg_loss = state.get("moving_avg_loss", 0.0)
            self._min_loss = state.get("min_loss") or float('inf')
            self._total_steps = state.get("total_steps", 0)
            self._total_tokens = state.get("total_tokens", 0)
            self._current_shard = state.get("current_shard", 0)
            self._last_model_hash = state.get("last_model_hash", "")
            self._loss_checkpoints = state.get("loss_checkpoints", [])

            # Restore global stats
            global_stats = state.get("global_stats", {})
            self._global_stats.global_avg_loss = global_stats.get("global_avg_loss", 0.0)
            self._global_stats.global_min_loss = global_stats.get("global_min_loss") or float('inf')
            self._global_stats.total_nodes_training = global_stats.get("total_nodes_training", 0)
            self._global_stats.global_steps = global_stats.get("global_steps", 0)
            self._global_stats.global_tokens = global_stats.get("global_tokens", 0)
            self._global_stats.successful_syncs = global_stats.get("successful_syncs", 0)
            self._global_stats.failed_syncs = global_stats.get("failed_syncs", 0)

            logger.info(f"Restored tracker state: {self._total_steps} steps, avg_loss={self._moving_avg_loss:.4f}")
        except Exception as e:
            logger.warning(f"Failed to load tracker state: {e}")

    def record_step(
        self,
        loss: float,
        step: int,
        shard_id: int = 0,
        tokens_in_batch: int = 0,
        gradient_norm: Optional[float] = None,
        inner_step: int = 0,
        outer_step: int = 0,
    ) -> TrainingSnapshot:
        """
        Record a training step and update metrics.

        Args:
            loss: Raw loss value for this batch
            step: Global training step number
            shard_id: Data shard being trained on
            tokens_in_batch: Tokens processed in this batch
            gradient_norm: Optional gradient L2 norm
            inner_step: DiLoCo inner step (0-500)
            outer_step: DiLoCo outer sync count

        Returns:
            TrainingSnapshot with current state
        """
        # Update EMA loss
        if self._moving_avg_loss == 0.0:
            self._moving_avg_loss = loss
        else:
            self._moving_avg_loss = (
                self.EMA_ALPHA * loss +
                (1 - self.EMA_ALPHA) * self._moving_avg_loss
            )

        # Track minimum loss
        if loss < self._min_loss:
            self._min_loss = loss

        # Update counters
        self._total_steps = step
        self._total_tokens += tokens_in_batch
        self._current_shard = shard_id

        # Compute model hash (every 50 steps to save compute)
        if step % 50 == 0:
            self._last_model_hash = self._compute_model_hash()

        # Create snapshot
        snapshot = TrainingSnapshot(
            timestamp=time.time(),
            node_id=self.node_id,
            batch_loss=loss,
            moving_avg_loss=self._moving_avg_loss,
            min_loss_seen=self._min_loss,
            training_step=step,
            inner_step=inner_step,
            outer_step=outer_step,
            model_hash=self._last_model_hash,
            gradient_norm=gradient_norm or 0.0,
            shard_id=shard_id,
            tokens_trained=self._total_tokens,
        )

        self._local_history.append(snapshot)

        # Record loss checkpoint every 100 steps
        if step % 100 == 0:
            self._loss_checkpoints.append((step, self._moving_avg_loss))
            # Keep only last 100 checkpoints
            if len(self._loss_checkpoints) > 100:
                self._loss_checkpoints = self._loss_checkpoints[-100:]

        # Update global stats
        self._update_global_stats()

        # Periodically save state (every 10 steps to match checkpoint frequency)
        if step % 10 == 0:
            self._save_state()

        return snapshot

    def _compute_model_hash(self) -> str:
        """Compute SHA256 hash of model weights (sampled for speed)."""
        hasher = hashlib.sha256()

        # Sample some parameters for speed
        params_to_hash = list(self.model.named_parameters())[:10]

        for name, param in params_to_hash:
            hasher.update(name.encode())
            # Sample first 1000 values
            data = param.data.flatten()[:1000].cpu().numpy().tobytes()
            hasher.update(data)

        return hasher.hexdigest()[:16]

    def receive_peer_stats(self, peer_id: str, snapshot_data: Dict[str, Any]):
        """
        Receive training stats from a peer via gossip.

        Args:
            peer_id: ID of peer node
            snapshot_data: Serialized TrainingSnapshot
        """
        with self._peer_stats_lock:
            snapshot = TrainingSnapshot(
                timestamp=snapshot_data.get("timestamp", time.time()),
                node_id=peer_id,
                batch_loss=snapshot_data.get("batch_loss", 0.0),
                moving_avg_loss=snapshot_data.get("moving_avg_loss", 0.0),
                min_loss_seen=snapshot_data.get("min_loss_seen", float('inf')),
                training_step=snapshot_data.get("training_step", 0),
                inner_step=snapshot_data.get("inner_step", 0),
                outer_step=snapshot_data.get("outer_step", 0),
                model_hash=snapshot_data.get("model_hash", ""),
                gradient_norm=snapshot_data.get("gradient_norm", 0.0),
                shard_id=snapshot_data.get("shard_id", 0),
                tokens_trained=snapshot_data.get("tokens_trained", 0),
            )

            self._peer_stats[peer_id] = snapshot

            # Clean up stale peers (>5 min old)
            now = time.time()
            stale_peers = [
                pid for pid, s in self._peer_stats.items()
                if now - s.timestamp > 300
            ]
            for pid in stale_peers:
                del self._peer_stats[pid]

        self._update_global_stats()

    def record_sync_result(self, success: bool, peers_synced: int = 0):
        """Record a gradient sync attempt result."""
        self._sync_history.append((time.time(), success, peers_synced))

        if success:
            self._global_stats.successful_syncs += 1
        else:
            self._global_stats.failed_syncs += 1

        # Persist state after each sync (important milestone)
        self._save_state()

    def _update_global_stats(self):
        """Recompute global statistics from local + peer data."""
        with self._peer_stats_lock:
            # Collect all snapshots (local + peers)
            all_snapshots = list(self._peer_stats.values())

        # Add our own latest
        if self._local_history:
            all_snapshots.append(self._local_history[-1])

        if not all_snapshots:
            return

        # Compute global averages
        self._global_stats.global_avg_loss = sum(
            s.moving_avg_loss for s in all_snapshots
        ) / len(all_snapshots)

        self._global_stats.global_min_loss = min(
            s.min_loss_seen for s in all_snapshots
        )

        self._global_stats.total_nodes_training = len(all_snapshots)

        self._global_stats.global_steps = max(
            s.training_step for s in all_snapshots
        )

        self._global_stats.global_tokens = sum(
            s.tokens_trained for s in all_snapshots
        )

        # Track data shard coverage
        self._global_stats.data_shards_covered = set(
            s.shard_id for s in all_snapshots
        )

        # Check model hash convergence
        hashes = [s.model_hash for s in all_snapshots if s.model_hash]
        if hashes:
            # Count most common hash
            from collections import Counter
            hash_counts = Counter(hashes)
            most_common_hash, count = hash_counts.most_common(1)[0]
            self._global_stats.hash_agreement_rate = count / len(hashes)
            self._global_stats.model_hashes = {
                s.node_id: s.model_hash for s in all_snapshots
            }

        self._global_stats.last_update = time.time()

    @staticmethod
    def _sanitize_float(value: float) -> Optional[float]:
        """Convert inf/nan to None for JSON serialization."""
        import math
        if value is None or math.isinf(value) or math.isnan(value):
            return None
        return value

    def get_local_status(self) -> Dict[str, Any]:
        """Get this node's training status."""
        return {
            "node_id": self.node_id,
            "training_step": self._total_steps,
            "moving_avg_loss": self._sanitize_float(self._moving_avg_loss),
            "min_loss_seen": self._sanitize_float(self._min_loss),
            "tokens_trained": self._total_tokens,
            "current_shard": self._current_shard,
            "model_hash": self._last_model_hash,
            "loss_trend": self._compute_loss_trend(),
        }

    def get_global_status(self) -> Dict[str, Any]:
        """
        Get network-wide training status.

        This is the key method for verifying distributed training is working.

        Returns:
            Dict with:
            - global_avg_loss: Average loss across all nodes
            - global_min_loss: Best loss achieved by any node
            - hash_agreement_rate: % of nodes with same model hash (should be 100%)
            - total_nodes_training: Number of active nodes
            - is_converging: Whether the network appears to be converging
            - training_verified: Whether training is definitely improving the model
        """
        # Check if training is actually improving
        is_converging = self._check_convergence()
        training_verified = self._verify_training()

        return {
            # Global metrics (sanitized for JSON)
            "global_avg_loss": self._sanitize_float(self._global_stats.global_avg_loss),
            "global_min_loss": self._sanitize_float(self._global_stats.global_min_loss),

            # Convergence
            "hash_agreement_rate": self._global_stats.hash_agreement_rate,
            "model_hashes": dict(self._global_stats.model_hashes),

            # Network health
            "total_nodes_training": self._global_stats.total_nodes_training,
            "successful_syncs": self._global_stats.successful_syncs,
            "failed_syncs": self._global_stats.failed_syncs,
            "sync_success_rate": (
                self._global_stats.successful_syncs /
                max(1, self._global_stats.successful_syncs + self._global_stats.failed_syncs)
            ),

            # Progress
            "global_steps": self._global_stats.global_steps,
            "global_tokens": self._global_stats.global_tokens,
            "data_shards_covered": list(self._global_stats.data_shards_covered),

            # Verification
            "is_converging": is_converging,
            "training_verified": training_verified,
            "loss_trend": self._compute_loss_trend(),

            # Timestamp
            "last_update": self._global_stats.last_update,
        }

    def _compute_loss_trend(self) -> str:
        """Compute loss trend over recent history."""
        if len(self._loss_checkpoints) < 2:
            return "insufficient_data"

        # Compare first half to second half
        mid = len(self._loss_checkpoints) // 2
        first_half_avg = sum(l for _, l in self._loss_checkpoints[:mid]) / mid
        second_half_avg = sum(l for _, l in self._loss_checkpoints[mid:]) / (len(self._loss_checkpoints) - mid)

        improvement = (first_half_avg - second_half_avg) / first_half_avg if first_half_avg > 0 else 0

        if improvement > 0.1:
            return "improving_strongly"
        elif improvement > 0.02:
            return "improving"
        elif improvement > -0.02:
            return "stable"
        elif improvement > -0.1:
            return "degrading_slightly"
        else:
            return "degrading"

    def _check_convergence(self) -> bool:
        """Check if the network appears to be converging."""
        # Need at least 2 nodes with matching hashes
        if self._global_stats.hash_agreement_rate < 0.5:
            return False

        # Loss should be trending down
        trend = self._compute_loss_trend()
        return trend in ["improving", "improving_strongly", "stable"]

    def _verify_training(self) -> bool:
        """
        Verify that training is actually improving the model.

        Returns True if we can confirm the model is learning.
        """
        # Need sufficient data
        if len(self._loss_checkpoints) < 5:
            return False

        # Check that loss has decreased overall
        first_losses = [l for _, l in self._loss_checkpoints[:3]]
        recent_losses = [l for _, l in self._loss_checkpoints[-3:]]

        first_avg = sum(first_losses) / len(first_losses)
        recent_avg = sum(recent_losses) / len(recent_losses)

        # Loss should have decreased by at least 10%
        return recent_avg < first_avg * 0.9

    def get_snapshot_for_gossip(self) -> Dict[str, Any]:
        """Get current snapshot data to send to peers."""
        if not self._local_history:
            return {}

        latest = self._local_history[-1]
        return {
            "timestamp": latest.timestamp,
            "batch_loss": latest.batch_loss,
            "moving_avg_loss": latest.moving_avg_loss,
            "min_loss_seen": latest.min_loss_seen,
            "training_step": latest.training_step,
            "inner_step": latest.inner_step,
            "outer_step": latest.outer_step,
            "model_hash": latest.model_hash,
            "gradient_norm": latest.gradient_norm,
            "shard_id": latest.shard_id,
            "tokens_trained": latest.tokens_trained,
        }

    def save_training_log(self, filename: str = None):
        """Save training history to disk for analysis."""
        if filename is None:
            filename = f"training_log_{self.node_id[:8]}_{int(time.time())}.json"

        filepath = self.checkpoint_dir / filename

        log_data = {
            "node_id": self.node_id,
            "saved_at": time.time(),
            "local_status": self.get_local_status(),
            "global_status": self.get_global_status(),
            "loss_checkpoints": self._loss_checkpoints,
            "sync_history": list(self._sync_history),
        }

        with open(filepath, 'w') as f:
            json.dump(log_data, f, indent=2, default=str)

        logger.info(f"Training log saved to {filepath}")
        return filepath


def format_training_status(tracker: GlobalTrainingTracker) -> str:
    """Format training status for display."""
    local = tracker.get_local_status()
    global_stats = tracker.get_global_status()

    lines = [
        "=" * 60,
        "NEUROSHARD GLOBAL TRAINING STATUS",
        "=" * 60,
        "",
        "LOCAL NODE:",
        f"  Step: {local['training_step']:,}",
        f"  Loss: {local['moving_avg_loss']:.4f} (min: {local['min_loss_seen']:.4f})",
        f"  Tokens: {local['tokens_trained']:,}",
        f"  Trend: {local['loss_trend']}",
        f"  Model Hash: {local['model_hash']}",
        "",
        "GLOBAL NETWORK:",
        f"  Nodes Training: {global_stats['total_nodes_training']}",
        f"  Global Avg Loss: {global_stats['global_avg_loss']:.4f}",
        f"  Global Min Loss: {global_stats['global_min_loss']:.4f}",
        f"  Hash Agreement: {global_stats['hash_agreement_rate']*100:.1f}%",
        f"  Shards Covered: {len(global_stats['data_shards_covered'])}",
        "",
        "VERIFICATION:",
        f"  Is Converging: {'✓' if global_stats['is_converging'] else '✗'}",
        f"  Training Verified: {'✓' if global_stats['training_verified'] else '✗'}",
        f"  Sync Success Rate: {global_stats['sync_success_rate']*100:.1f}%",
        "",
        "=" * 60,
    ]

    return "\n".join(lines)
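
For readers skimming this diff, a minimal driver sketch follows, showing how the added tracker can be exercised end to end using only the public names defined in the file above (GlobalTrainingTracker, record_step, receive_peer_stats, record_sync_result, get_snapshot_for_gossip, format_training_status). The toy nn.Linear model, the node IDs, the synthetic decreasing loss, and the /tmp checkpoint directory are illustrative assumptions, not part of the package.

# Illustrative driver sketch -- not shipped in the wheel; model, node IDs,
# loss curve, and checkpoint_dir are placeholder assumptions.
from pathlib import Path

import torch.nn as nn

from neuroshard.core.training.global_tracker import (
    GlobalTrainingTracker,
    format_training_status,
)

model = nn.Linear(16, 16)  # stand-in for the real model
tracker = GlobalTrainingTracker(
    "node-aaaaaaaa", model, checkpoint_dir=Path("/tmp/neuroshard_demo")
)

# Feed a synthetic, steadily decreasing loss so the trend/verification logic
# has enough checkpoints (one is recorded every 100 steps).
for step in range(1, 501):
    tracker.record_step(loss=4.0 / step, step=step, shard_id=0, tokens_in_batch=1024)

# A peer's gossip payload is whatever get_snapshot_for_gossip() returns on that
# peer; here we loop our own snapshot back under a hypothetical second node ID.
tracker.receive_peer_stats("node-bbbbbbbb", tracker.get_snapshot_for_gossip())
tracker.record_sync_result(success=True, peers_synced=1)

print(format_training_status(tracker))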