nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/training/global_tracker.py
@@ -0,0 +1,617 @@
+"""
+Global Training Tracker - Verify Distributed Training is Working
+
+This module provides network-wide training verification:
+1. Tracks loss across ALL nodes in the network
+2. Monitors model hash convergence (ensures nodes sync)
+3. Computes global training metrics (not just local batch loss)
+4. Provides dashboards/APIs for monitoring
+
+Key Concepts:
+- Moving Average Loss: Smoothed loss over time (local)
+- Global Loss: Average loss across all network nodes
+- Model Hash: SHA256 of model weights (should converge across nodes)
+- Sync Rate: How often nodes successfully sync gradients
+"""
+
+import time
+import hashlib
+import threading
+import logging
+import json
+from collections import deque
+from dataclasses import dataclass, field
+from typing import Dict, List, Optional, Any, Tuple
+from pathlib import Path
+import torch
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class TrainingSnapshot:
+    """A snapshot of training state at a point in time."""
+    timestamp: float
+    node_id: str
+
+    # Loss metrics
+    batch_loss: float # Raw batch loss
+    moving_avg_loss: float # Smoothed loss (EMA)
+    min_loss_seen: float # Best loss achieved
+
+    # Training progress
+    training_step: int # Global step count
+    inner_step: int # DiLoCo inner step (0-500)
+    outer_step: int # DiLoCo outer step (sync count)
+
+    # Convergence metrics
+    model_hash: str # Hash of model weights
+    gradient_norm: float # L2 norm of gradients
+
+    # Data coverage
+    shard_id: int # Current data shard
+    tokens_trained: int # Total tokens seen
+
+
+@dataclass
+class GlobalTrainingStats:
+    """Network-wide training statistics."""
+    # Aggregated metrics
+    global_avg_loss: float = 0.0
+    global_min_loss: float = float('inf')
+
+    # Convergence tracking
+    model_hashes: Dict[str, str] = field(default_factory=dict) # node_id -> hash
+    hash_agreement_rate: float = 0.0 # % of nodes with same hash
+
+    # Network health
+    total_nodes_training: int = 0
+    successful_syncs: int = 0
+    failed_syncs: int = 0
+
+    # Progress
+    global_steps: int = 0
+    global_tokens: int = 0
+    data_shards_covered: set = field(default_factory=set)
+
+    # Time tracking
+    last_update: float = 0.0
+
+
+class GlobalTrainingTracker:
+    """
+    Tracks and verifies distributed training across the network.
+
+    Each node runs this to:
+    1. Track its own training progress
+    2. Receive training stats from peers via gossip
+    3. Compute global metrics
+    4. Verify convergence (all nodes should have similar model hash)
+
+    Usage:
+        tracker = GlobalTrainingTracker(node_id, model)
+
+        # During training
+        tracker.record_step(loss, step, shard_id)
+
+        # Get status
+        status = tracker.get_global_status()
+        print(f"Global loss: {status['global_avg_loss']:.4f}")
+        print(f"Hash agreement: {status['hash_agreement_rate']*100:.1f}%")
+    """
+
+    # EMA smoothing factor (lower = smoother)
+    EMA_ALPHA = 0.1
+
+    # History window for computing metrics
+    HISTORY_WINDOW = 100
+
+    # Minimum nodes for global metrics
+    MIN_NODES_FOR_GLOBAL = 1
+
+    def __init__(
+        self,
+        node_id: str,
+        model: nn.Module,
+        checkpoint_dir: Optional[Path] = None,
+    ):
+        self.node_id = node_id
+        self.model = model
+        self.checkpoint_dir = checkpoint_dir or Path.home() / ".neuroshard" / "training_logs"
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        # Local tracking
+        self._local_history: deque = deque(maxlen=self.HISTORY_WINDOW)
+        self._moving_avg_loss = 0.0
+        self._min_loss = float('inf')
+        self._total_steps = 0
+        self._total_tokens = 0
+        self._current_shard = 0
+
+        # Peer tracking (received via gossip)
+        self._peer_stats: Dict[str, TrainingSnapshot] = {}
+        self._peer_stats_lock = threading.Lock()
+
+        # Global aggregated stats
+        self._global_stats = GlobalTrainingStats()
+
+        # Sync tracking
+        self._sync_history: deque = deque(maxlen=50) # (timestamp, success)
+        self._last_model_hash = ""
+
+        # Verification
+        self._loss_checkpoints: List[Tuple[int, float]] = [] # (step, loss)
+
+        # Try to restore previous state
+        self._load_state()
+
+        logger.info(f"GlobalTrainingTracker initialized for node {node_id[:8]}...")
+
+    def _get_state_path(self) -> Path:
+        """Get path to state file."""
+        return self.checkpoint_dir / f"tracker_state_{self.node_id[:16]}.json"
+
+    def _save_state(self):
+        """Persist tracker state to disk."""
+        try:
+            state = {
+                "node_id": self.node_id,
+                "saved_at": time.time(),
+                "moving_avg_loss": self._moving_avg_loss,
+                "min_loss": self._min_loss if self._min_loss != float('inf') else None,
+                "total_steps": self._total_steps,
+                "total_tokens": self._total_tokens,
+                "current_shard": self._current_shard,
+                "last_model_hash": self._last_model_hash,
+                "loss_checkpoints": self._loss_checkpoints[-100:], # Keep last 100
+                "global_stats": {
+                    "global_avg_loss": self._global_stats.global_avg_loss,
+                    "global_min_loss": self._global_stats.global_min_loss if self._global_stats.global_min_loss != float('inf') else None,
+                    "total_nodes_training": self._global_stats.total_nodes_training,
+                    "global_steps": self._global_stats.global_steps,
+                    "global_tokens": self._global_stats.global_tokens,
+                    "successful_syncs": self._global_stats.successful_syncs,
+                    "failed_syncs": self._global_stats.failed_syncs,
+                },
+            }
+
+            with open(self._get_state_path(), 'w') as f:
+                json.dump(state, f, indent=2, default=str)
+
+            logger.debug(f"Tracker state saved: {self._total_steps} steps, loss={self._moving_avg_loss:.4f}")
+        except Exception as e:
+            logger.warning(f"Failed to save tracker state: {e}")
+
+    def _load_state(self):
+        """Load tracker state from disk."""
+        path = self._get_state_path()
+        if not path.exists():
+            logger.debug("No previous tracker state found")
+            return
+
+        try:
+            with open(path, 'r') as f:
+                state = json.load(f)
+
+            # Restore local state
+            self._moving_avg_loss = state.get("moving_avg_loss", 0.0)
+            self._min_loss = state.get("min_loss") or float('inf')
+            self._total_steps = state.get("total_steps", 0)
+            self._total_tokens = state.get("total_tokens", 0)
+            self._current_shard = state.get("current_shard", 0)
+            self._last_model_hash = state.get("last_model_hash", "")
+            self._loss_checkpoints = state.get("loss_checkpoints", [])
+
+            # Restore global stats
+            global_stats = state.get("global_stats", {})
+            self._global_stats.global_avg_loss = global_stats.get("global_avg_loss", 0.0)
+            self._global_stats.global_min_loss = global_stats.get("global_min_loss") or float('inf')
+            self._global_stats.total_nodes_training = global_stats.get("total_nodes_training", 0)
+            self._global_stats.global_steps = global_stats.get("global_steps", 0)
+            self._global_stats.global_tokens = global_stats.get("global_tokens", 0)
+            self._global_stats.successful_syncs = global_stats.get("successful_syncs", 0)
+            self._global_stats.failed_syncs = global_stats.get("failed_syncs", 0)
+
+            logger.info(f"Restored tracker state: {self._total_steps} steps, avg_loss={self._moving_avg_loss:.4f}")
+        except Exception as e:
+            logger.warning(f"Failed to load tracker state: {e}")
+
+    def record_step(
+        self,
+        loss: float,
+        step: int,
+        shard_id: int = 0,
+        tokens_in_batch: int = 0,
+        gradient_norm: Optional[float] = None,
+        inner_step: int = 0,
+        outer_step: int = 0,
+    ) -> TrainingSnapshot:
+        """
+        Record a training step and update metrics.
+
+        Args:
+            loss: Raw loss value for this batch
+            step: Global training step number
+            shard_id: Data shard being trained on
+            tokens_in_batch: Tokens processed in this batch
+            gradient_norm: Optional gradient L2 norm
+            inner_step: DiLoCo inner step (0-500)
+            outer_step: DiLoCo outer sync count
+
+        Returns:
+            TrainingSnapshot with current state
+        """
+        # Update EMA loss
+        if self._moving_avg_loss == 0.0:
+            self._moving_avg_loss = loss
+        else:
+            self._moving_avg_loss = (
+                self.EMA_ALPHA * loss +
+                (1 - self.EMA_ALPHA) * self._moving_avg_loss
+            )
+
+        # Track minimum loss
+        if loss < self._min_loss:
+            self._min_loss = loss
+
+        # Update counters
+        self._total_steps = step
+        self._total_tokens += tokens_in_batch
+        self._current_shard = shard_id
+
+        # Compute model hash (every 50 steps to save compute)
+        if step % 50 == 0:
+            self._last_model_hash = self._compute_model_hash()
+
+        # Create snapshot
+        snapshot = TrainingSnapshot(
+            timestamp=time.time(),
+            node_id=self.node_id,
+            batch_loss=loss,
+            moving_avg_loss=self._moving_avg_loss,
+            min_loss_seen=self._min_loss,
+            training_step=step,
+            inner_step=inner_step,
+            outer_step=outer_step,
+            model_hash=self._last_model_hash,
+            gradient_norm=gradient_norm or 0.0,
+            shard_id=shard_id,
+            tokens_trained=self._total_tokens,
+        )
+
+        self._local_history.append(snapshot)
+
+        # Record loss checkpoint every 100 steps
+        if step % 100 == 0:
+            self._loss_checkpoints.append((step, self._moving_avg_loss))
+            # Keep only last 100 checkpoints
+            if len(self._loss_checkpoints) > 100:
+                self._loss_checkpoints = self._loss_checkpoints[-100:]
+
+        # Update global stats
+        self._update_global_stats()
+
+        # Periodically save state (every 10 steps to match checkpoint frequency)
+        if step % 10 == 0:
+            self._save_state()
+
+        return snapshot
+
+    def _compute_model_hash(self) -> str:
+        """Compute SHA256 hash of model weights (sampled for speed)."""
+        hasher = hashlib.sha256()
+
+        # Sample some parameters for speed
+        params_to_hash = list(self.model.named_parameters())[:10]
+
+        for name, param in params_to_hash:
+            hasher.update(name.encode())
+            # Sample first 1000 values
+            data = param.data.flatten()[:1000].cpu().numpy().tobytes()
+            hasher.update(data)
+
+        return hasher.hexdigest()[:16]
+
+    def receive_peer_stats(self, peer_id: str, snapshot_data: Dict[str, Any]):
+        """
+        Receive training stats from a peer via gossip.
+
+        Args:
+            peer_id: ID of peer node
+            snapshot_data: Serialized TrainingSnapshot
+        """
+        with self._peer_stats_lock:
+            snapshot = TrainingSnapshot(
+                timestamp=snapshot_data.get("timestamp", time.time()),
+                node_id=peer_id,
+                batch_loss=snapshot_data.get("batch_loss", 0.0),
+                moving_avg_loss=snapshot_data.get("moving_avg_loss", 0.0),
+                min_loss_seen=snapshot_data.get("min_loss_seen", float('inf')),
+                training_step=snapshot_data.get("training_step", 0),
+                inner_step=snapshot_data.get("inner_step", 0),
+                outer_step=snapshot_data.get("outer_step", 0),
+                model_hash=snapshot_data.get("model_hash", ""),
+                gradient_norm=snapshot_data.get("gradient_norm", 0.0),
+                shard_id=snapshot_data.get("shard_id", 0),
+                tokens_trained=snapshot_data.get("tokens_trained", 0),
+            )
+
+            self._peer_stats[peer_id] = snapshot
+
+            # Clean up stale peers (>5 min old)
+            now = time.time()
+            stale_peers = [
+                pid for pid, s in self._peer_stats.items()
+                if now - s.timestamp > 300
+            ]
+            for pid in stale_peers:
+                del self._peer_stats[pid]
+
+        self._update_global_stats()
+
+    def record_sync_result(self, success: bool, peers_synced: int = 0):
+        """Record a gradient sync attempt result."""
+        self._sync_history.append((time.time(), success, peers_synced))
+
+        if success:
+            self._global_stats.successful_syncs += 1
+        else:
+            self._global_stats.failed_syncs += 1
+
+        # Persist state after each sync (important milestone)
+        self._save_state()
+
+    def _update_global_stats(self):
+        """Recompute global statistics from local + peer data."""
+        with self._peer_stats_lock:
+            # Collect all snapshots (local + peers)
+            all_snapshots = list(self._peer_stats.values())
+
+            # Add our own latest
+            if self._local_history:
+                all_snapshots.append(self._local_history[-1])
+
+            if not all_snapshots:
+                return
+
+            # Compute global averages
+            self._global_stats.global_avg_loss = sum(
+                s.moving_avg_loss for s in all_snapshots
+            ) / len(all_snapshots)
+
+            self._global_stats.global_min_loss = min(
+                s.min_loss_seen for s in all_snapshots
+            )
+
+            self._global_stats.total_nodes_training = len(all_snapshots)
+
+            self._global_stats.global_steps = max(
+                s.training_step for s in all_snapshots
+            )
+
+            self._global_stats.global_tokens = sum(
+                s.tokens_trained for s in all_snapshots
+            )
+
+            # Track data shard coverage
+            self._global_stats.data_shards_covered = set(
+                s.shard_id for s in all_snapshots
+            )
+
+            # Check model hash convergence
+            hashes = [s.model_hash for s in all_snapshots if s.model_hash]
+            if hashes:
+                # Count most common hash
+                from collections import Counter
+                hash_counts = Counter(hashes)
+                most_common_hash, count = hash_counts.most_common(1)[0]
+                self._global_stats.hash_agreement_rate = count / len(hashes)
+                self._global_stats.model_hashes = {
+                    s.node_id: s.model_hash for s in all_snapshots
+                }
+
+            self._global_stats.last_update = time.time()
+
+    @staticmethod
+    def _sanitize_float(value: float) -> Optional[float]:
+        """Convert inf/nan to None for JSON serialization."""
+        import math
+        if value is None or math.isinf(value) or math.isnan(value):
+            return None
+        return value
+
+    def get_local_status(self) -> Dict[str, Any]:
+        """Get this node's training status."""
+        return {
+            "node_id": self.node_id,
+            "training_step": self._total_steps,
+            "moving_avg_loss": self._sanitize_float(self._moving_avg_loss),
+            "min_loss_seen": self._sanitize_float(self._min_loss),
+            "tokens_trained": self._total_tokens,
+            "current_shard": self._current_shard,
+            "model_hash": self._last_model_hash,
+            "loss_trend": self._compute_loss_trend(),
+        }
+
+    def get_global_status(self) -> Dict[str, Any]:
+        """
+        Get network-wide training status.
+
+        This is the key method for verifying distributed training is working.
+
+        Returns:
+            Dict with:
+            - global_avg_loss: Average loss across all nodes
+            - global_min_loss: Best loss achieved by any node
+            - hash_agreement_rate: % of nodes with same model hash (should be 100%)
+            - total_nodes_training: Number of active nodes
+            - is_converging: Whether the network appears to be converging
+            - training_verified: Whether training is definitely improving the model
+        """
+        # Check if training is actually improving
+        is_converging = self._check_convergence()
+        training_verified = self._verify_training()
+
+        return {
+            # Global metrics (sanitized for JSON)
+            "global_avg_loss": self._sanitize_float(self._global_stats.global_avg_loss),
+            "global_min_loss": self._sanitize_float(self._global_stats.global_min_loss),
+
+            # Convergence
+            "hash_agreement_rate": self._global_stats.hash_agreement_rate,
+            "model_hashes": dict(self._global_stats.model_hashes),
+
+            # Network health
+            "total_nodes_training": self._global_stats.total_nodes_training,
+            "successful_syncs": self._global_stats.successful_syncs,
+            "failed_syncs": self._global_stats.failed_syncs,
+            "sync_success_rate": (
+                self._global_stats.successful_syncs /
+                max(1, self._global_stats.successful_syncs + self._global_stats.failed_syncs)
+            ),
+
+            # Progress
+            "global_steps": self._global_stats.global_steps,
+            "global_tokens": self._global_stats.global_tokens,
+            "data_shards_covered": list(self._global_stats.data_shards_covered),
+
+            # Verification
+            "is_converging": is_converging,
+            "training_verified": training_verified,
+            "loss_trend": self._compute_loss_trend(),
+
+            # Timestamp
+            "last_update": self._global_stats.last_update,
+        }
+
+    def _compute_loss_trend(self) -> str:
+        """Compute loss trend over recent history."""
+        if len(self._loss_checkpoints) < 2:
+            return "insufficient_data"
+
+        # Compare first half to second half
+        mid = len(self._loss_checkpoints) // 2
+        first_half_avg = sum(l for _, l in self._loss_checkpoints[:mid]) / mid
+        second_half_avg = sum(l for _, l in self._loss_checkpoints[mid:]) / (len(self._loss_checkpoints) - mid)
+
+        improvement = (first_half_avg - second_half_avg) / first_half_avg if first_half_avg > 0 else 0
+
+        if improvement > 0.1:
+            return "improving_strongly"
+        elif improvement > 0.02:
+            return "improving"
+        elif improvement > -0.02:
+            return "stable"
+        elif improvement > -0.1:
+            return "degrading_slightly"
+        else:
+            return "degrading"
+
+    def _check_convergence(self) -> bool:
+        """Check if the network appears to be converging."""
+        # Need at least 2 nodes with matching hashes
+        if self._global_stats.hash_agreement_rate < 0.5:
+            return False
+
+        # Loss should be trending down
+        trend = self._compute_loss_trend()
+        return trend in ["improving", "improving_strongly", "stable"]
+
+    def _verify_training(self) -> bool:
+        """
+        Verify that training is actually improving the model.
+
+        Returns True if we can confirm the model is learning.
+        """
+        # Need sufficient data
+        if len(self._loss_checkpoints) < 5:
+            return False
+
+        # Check that loss has decreased overall
+        first_losses = [l for _, l in self._loss_checkpoints[:3]]
+        recent_losses = [l for _, l in self._loss_checkpoints[-3:]]
+
+        first_avg = sum(first_losses) / len(first_losses)
+        recent_avg = sum(recent_losses) / len(recent_losses)
+
+        # Loss should have decreased by at least 10%
+        return recent_avg < first_avg * 0.9
+
+    def get_snapshot_for_gossip(self) -> Dict[str, Any]:
+        """Get current snapshot data to send to peers."""
+        if not self._local_history:
+            return {}
+
+        latest = self._local_history[-1]
+        return {
+            "timestamp": latest.timestamp,
+            "batch_loss": latest.batch_loss,
+            "moving_avg_loss": latest.moving_avg_loss,
+            "min_loss_seen": latest.min_loss_seen,
+            "training_step": latest.training_step,
+            "inner_step": latest.inner_step,
+            "outer_step": latest.outer_step,
+            "model_hash": latest.model_hash,
+            "gradient_norm": latest.gradient_norm,
+            "shard_id": latest.shard_id,
+            "tokens_trained": latest.tokens_trained,
+        }
+
+    def save_training_log(self, filename: str = None):
+        """Save training history to disk for analysis."""
+        if filename is None:
+            filename = f"training_log_{self.node_id[:8]}_{int(time.time())}.json"
+
+        filepath = self.checkpoint_dir / filename
+
+        log_data = {
+            "node_id": self.node_id,
+            "saved_at": time.time(),
+            "local_status": self.get_local_status(),
+            "global_status": self.get_global_status(),
+            "loss_checkpoints": self._loss_checkpoints,
+            "sync_history": list(self._sync_history),
+        }
+
+        with open(filepath, 'w') as f:
+            json.dump(log_data, f, indent=2, default=str)
+
+        logger.info(f"Training log saved to {filepath}")
+        return filepath
+
+
+def format_training_status(tracker: GlobalTrainingTracker) -> str:
+    """Format training status for display."""
+    local = tracker.get_local_status()
+    global_stats = tracker.get_global_status()
+
+    lines = [
+        "=" * 60,
+        "NEUROSHARD GLOBAL TRAINING STATUS",
+        "=" * 60,
+        "",
+        "LOCAL NODE:",
+        f" Step: {local['training_step']:,}",
+        f" Loss: {local['moving_avg_loss']:.4f} (min: {local['min_loss_seen']:.4f})",
+        f" Tokens: {local['tokens_trained']:,}",
+        f" Trend: {local['loss_trend']}",
+        f" Model Hash: {local['model_hash']}",
+        "",
+        "GLOBAL NETWORK:",
+        f" Nodes Training: {global_stats['total_nodes_training']}",
+        f" Global Avg Loss: {global_stats['global_avg_loss']:.4f}",
+        f" Global Min Loss: {global_stats['global_min_loss']:.4f}",
+        f" Hash Agreement: {global_stats['hash_agreement_rate']*100:.1f}%",
+        f" Shards Covered: {len(global_stats['data_shards_covered'])}",
+        "",
+        "VERIFICATION:",
+        f" Is Converging: {'✓' if global_stats['is_converging'] else '✗'}",
+        f" Training Verified: {'✓' if global_stats['training_verified'] else '✗'}",
+        f" Sync Success Rate: {global_stats['sync_success_rate']*100:.1f}%",
+        "",
+        "=" * 60,
+    ]
+
+    return "\n".join(lines)
+
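For quick local verification of this module, the sketch below exercises GlobalTrainingTracker with a throwaway nn.Linear model and a synthetic, decaying loss curve. It is not part of the package; it assumes the wheel is installed so that the import path shown in the file list resolves, and the model, node id, and loss values are placeholders. Note that the tracker persists its state under ~/.neuroshard/training_logs/ by default.

import torch.nn as nn

from neuroshard.core.training.global_tracker import (
    GlobalTrainingTracker,
    format_training_status,
)

# Stand-in model; any nn.Module with parameters works for the weight hashing.
model = nn.Linear(16, 16)
tracker = GlobalTrainingTracker(node_id="local-smoke-test", model=model)

# Feed a synthetic, slowly decaying loss so the trend/verification logic has data.
for step in range(0, 1000, 10):
    loss = 2.5 * (0.999 ** step)
    tracker.record_step(loss=loss, step=step, shard_id=0, tokens_in_batch=1024)

tracker.record_sync_result(success=True)
print(format_training_status(tracker))

With no peers gossiping stats, the global figures collapse to the local node's own metrics, but the printout still shows the loss trend, hash agreement (trivially 100%), and whether training is "verified" by the 10% loss-decrease check.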