nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/checkpoint.py
@@ -0,0 +1,709 @@
+"""
+Speculative Checkpointing - High-Frequency Snapshots for Fast Recovery
+
+Implements background checkpointing for resilient distributed training:
+- Saves snapshots every 2 minutes (configurable)
+- Keeps rolling window of last N snapshots
+- Announces availability to DHT for peer recovery
+- Enables fast crash recovery vs full restart
+
+Key Insight: "Cheaper to over-checkpoint than to re-train."
+
+On crash, neighbors can fetch the "hot" snapshot and resume
+with minimal loss of training progress.
+
+Usage:
+    checkpointer = SpeculativeCheckpointer(
+        model=model,
+        optimizer=optimizer,
+        diloco_trainer=trainer,
+        checkpoint_dir="/path/to/checkpoints"
+    )
+    checkpointer.start()
+
+    # On crash recovery:
+    checkpoint = await checkpointer.fetch_neighbor_snapshot(peer_id)
+    checkpointer.restore_from_checkpoint(checkpoint)
+"""
+
+import asyncio
+import gzip
+import hashlib
+import io
+import logging
+import os
+import shutil
+import threading
+import time
+from dataclasses import dataclass, field
+from pathlib import Path
+from typing import Dict, List, Optional, Any, Callable, Tuple
+from enum import Enum
+
+import torch
+import torch.nn as nn
+
+logger = logging.getLogger(__name__)
+
+
+class CheckpointType(Enum):
+    """Type of checkpoint."""
+    HOT = "hot"            # Frequent speculative snapshot
+    COLD = "cold"          # Less frequent, more complete
+    RECOVERY = "recovery"  # Fetched from peer for recovery
+
+
+@dataclass
+class CheckpointMetadata:
+    """Metadata for a checkpoint."""
+    checkpoint_id: str
+    timestamp: float
+    checkpoint_type: CheckpointType
+
+    # Training state
+    training_step: int
+    outer_step: int
+    inner_step: int
+
+    # Model info
+    model_hash: str
+    num_params: int
+    layer_ids: List[int]
+
+    # Storage
+    file_path: str
+    compressed_size: int
+    original_size: int
+
+    # Node info
+    node_id: str
+
+    @property
+    def age_seconds(self) -> float:
+        return time.time() - self.timestamp
+
+    @property
+    def compression_ratio(self) -> float:
+        if self.original_size == 0:
+            return 1.0
+        return self.compressed_size / self.original_size
+
+
+@dataclass
+class CheckpointConfig:
+    """Configuration for speculative checkpointing."""
+    # Timing
+    snapshot_interval: float = 120.0          # 2 minutes
+    cold_checkpoint_interval: float = 3600.0  # 1 hour
+
+    # Storage
+    max_hot_snapshots: int = 5     # Keep last 5 hot snapshots
+    max_cold_checkpoints: int = 3  # Keep last 3 cold checkpoints
+    checkpoint_dir: str = "./checkpoints"
+
+    # Compression
+    compression_level: int = 6  # gzip compression (1-9)
+
+    # Networking
+    announce_to_dht: bool = True  # Announce availability
+    serve_to_peers: bool = True   # Allow peers to fetch
+
+    # Recovery
+    auto_fetch_on_start: bool = True  # Try to fetch from peers on start
+    recovery_timeout: float = 60.0    # Timeout for fetching
+
+
+class SpeculativeCheckpointer:
+    """
+    Background checkpointer for resilient distributed training.
+
+    Runs in a background thread, periodically saving:
+    - Hot snapshots (every 2 minutes) - for fast recovery
+    - Cold checkpoints (hourly) - more complete, for long-term storage
+
+    Integrates with:
+    - DiLoCoTrainer for training state
+    - P2P/DHT for checkpoint announcement and fetching
+    - gRPC for serving checkpoints to peers
+    """
+
+    def __init__(
+        self,
+        model: nn.Module,
+        optimizer: torch.optim.Optimizer,
+        diloco_trainer: Optional[Any] = None,  # DiLoCoTrainer
+        config: Optional[CheckpointConfig] = None,
+        node_id: str = "",
+        p2p_manager: Optional[Any] = None,
+    ):
+        """
+        Initialize speculative checkpointer.
+
+        Args:
+            model: Model to checkpoint
+            optimizer: Optimizer to checkpoint
+            diloco_trainer: Optional DiLoCo trainer for additional state
+            config: Checkpoint configuration
+            node_id: This node's ID
+            p2p_manager: P2P manager for DHT announcements
+        """
+        self.model = model
+        self.optimizer = optimizer
+        self.diloco = diloco_trainer
+        self.config = config or CheckpointConfig()
+        self.node_id = node_id
+        self.p2p = p2p_manager
+
+        # Ensure checkpoint directory exists
+        self.checkpoint_dir = Path(self.config.checkpoint_dir)
+        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)
+
+        # Snapshot tracking
+        self.hot_snapshots: List[CheckpointMetadata] = []
+        self.cold_checkpoints: List[CheckpointMetadata] = []
+
+        # Current state
+        self.training_step = 0
+        self.outer_step = 0
+        self.inner_step = 0
+
+        # Background thread
+        self.running = False
+        self._thread: Optional[threading.Thread] = None
+        self._last_hot_snapshot = 0.0
+        self._last_cold_checkpoint = 0.0
+
+        # Stats
+        self.snapshots_saved = 0
+        self.snapshots_served = 0
+        self.recoveries_performed = 0
+
+        # Lock for thread safety
+        self._lock = threading.RLock()
+
+        logger.info(f"SpeculativeCheckpointer initialized: "
+                    f"hot_interval={self.config.snapshot_interval}s, "
+                    f"cold_interval={self.config.cold_checkpoint_interval}s")
+
+    # ==================== LIFECYCLE ====================
+
+    def start(self):
+        """Start background checkpointing."""
+        if self.running:
+            return
+
+        self.running = True
+        self._thread = threading.Thread(
+            target=self._checkpoint_loop,
+            daemon=True,
+            name="SpeculativeCheckpointer"
+        )
+        self._thread.start()
+
+        logger.info("Speculative checkpointing started")
+
+    def stop(self):
+        """Stop background checkpointing."""
+        self.running = False
+        if self._thread:
+            self._thread.join(timeout=5.0)
+
+        logger.info("Speculative checkpointing stopped")
+
+    def _checkpoint_loop(self):
+        """Background loop for periodic checkpointing."""
+        while self.running:
+            try:
+                now = time.time()
+
+                # Hot snapshot check
+                if (now - self._last_hot_snapshot) >= self.config.snapshot_interval:
+                    self._save_hot_snapshot()
+                    self._last_hot_snapshot = now
+
+                # Cold checkpoint check
+                if (now - self._last_cold_checkpoint) >= self.config.cold_checkpoint_interval:
+                    self._save_cold_checkpoint()
+                    self._last_cold_checkpoint = now
+
+                # Sleep for a bit
+                time.sleep(10)  # Check every 10 seconds
+
+            except Exception as e:
+                logger.error(f"Checkpoint loop error: {e}")
+                time.sleep(30)  # Back off on error
+
+    # ==================== SAVING ====================
+
+    def _save_hot_snapshot(self) -> Optional[CheckpointMetadata]:
+        """Save a hot snapshot for fast recovery."""
+        with self._lock:
+            try:
+                checkpoint_id = f"hot_{int(time.time())}_{self.node_id[:8]}"
+                filename = f"{checkpoint_id}.pt.gz"
+                filepath = self.checkpoint_dir / filename
+
+                # Build checkpoint
+                checkpoint = self._build_checkpoint(CheckpointType.HOT)
+
+                # Save compressed
+                original_size, compressed_size = self._save_compressed(
+                    checkpoint, filepath
+                )
+
+                # Create metadata
+                metadata = CheckpointMetadata(
+                    checkpoint_id=checkpoint_id,
+                    timestamp=time.time(),
+                    checkpoint_type=CheckpointType.HOT,
+                    training_step=self.training_step,
+                    outer_step=self.outer_step,
+                    inner_step=self.inner_step,
+                    model_hash=self._compute_model_hash(),
+                    num_params=sum(p.numel() for p in self.model.parameters()),
+                    layer_ids=self._get_layer_ids(),
+                    file_path=str(filepath),
+                    compressed_size=compressed_size,
+                    original_size=original_size,
+                    node_id=self.node_id,
+                )
+
+                # Track snapshot
+                self.hot_snapshots.append(metadata)
+                self.snapshots_saved += 1
+
+                # Cleanup old snapshots
+                self._cleanup_old_snapshots()
+
+                # Announce to DHT
+                if self.config.announce_to_dht:
+                    self._announce_checkpoint(metadata)
+
+                logger.info(f"Hot snapshot saved: {checkpoint_id} "
+                            f"({compressed_size/1024:.1f}KB, "
+                            f"ratio={metadata.compression_ratio:.2f})")
+
+                return metadata
+
+            except Exception as e:
+                logger.error(f"Failed to save hot snapshot: {e}")
+                return None
+
+    def _save_cold_checkpoint(self) -> Optional[CheckpointMetadata]:
+        """Save a cold checkpoint with full state."""
+        with self._lock:
+            try:
+                checkpoint_id = f"cold_{int(time.time())}_{self.node_id[:8]}"
+                filename = f"{checkpoint_id}.pt.gz"
+                filepath = self.checkpoint_dir / filename
+
+                # Build checkpoint (more complete than hot)
+                checkpoint = self._build_checkpoint(CheckpointType.COLD)
+
+                # Save compressed
+                original_size, compressed_size = self._save_compressed(
+                    checkpoint, filepath
+                )
+
+                # Create metadata
+                metadata = CheckpointMetadata(
+                    checkpoint_id=checkpoint_id,
+                    timestamp=time.time(),
+                    checkpoint_type=CheckpointType.COLD,
+                    training_step=self.training_step,
+                    outer_step=self.outer_step,
+                    inner_step=self.inner_step,
+                    model_hash=self._compute_model_hash(),
+                    num_params=sum(p.numel() for p in self.model.parameters()),
+                    layer_ids=self._get_layer_ids(),
+                    file_path=str(filepath),
+                    compressed_size=compressed_size,
+                    original_size=original_size,
+                    node_id=self.node_id,
+                )
+
+                # Track checkpoint
+                self.cold_checkpoints.append(metadata)
+
+                # Cleanup old checkpoints
+                self._cleanup_old_checkpoints()
+
+                logger.info(f"Cold checkpoint saved: {checkpoint_id} "
+                            f"({compressed_size/1024/1024:.1f}MB)")
+
+                return metadata
+
+            except Exception as e:
+                logger.error(f"Failed to save cold checkpoint: {e}")
+                return None
+
+    def _build_checkpoint(self, checkpoint_type: CheckpointType) -> Dict[str, Any]:
+        """Build checkpoint dictionary."""
+        checkpoint = {
+            'checkpoint_type': checkpoint_type.value,
+            'timestamp': time.time(),
+            'node_id': self.node_id,
+
+            # Model state
+            'model_state_dict': self.model.state_dict(),
+
+            # Optimizer state
+            'optimizer_state_dict': self.optimizer.state_dict(),
+
+            # Training progress
+            'training_step': self.training_step,
+            'outer_step': self.outer_step,
+            'inner_step': self.inner_step,
+        }
+
+        # Add DiLoCo state if available
+        if self.diloco is not None:
+            checkpoint['diloco_state'] = {
+                'initial_weights': {
+                    k: v.clone() for k, v in self.diloco.initial_weights.items()
+                },
+                'outer_optimizer': self.diloco.outer_optimizer.state_dict(),
+                'stats': {
+                    'inner_step_count': self.diloco.stats.inner_step_count,
+                    'outer_step_count': self.diloco.stats.outer_step_count,
+                    'total_inner_steps': self.diloco.stats.total_inner_steps,
+                },
+                'phase': self.diloco.phase.value,
+            }
+
+        # For cold checkpoints, add extra info
+        if checkpoint_type == CheckpointType.COLD:
+            checkpoint['model_hash'] = self._compute_model_hash()
+            checkpoint['layer_ids'] = self._get_layer_ids()
+
+        return checkpoint
+
+    def _save_compressed(
+        self,
+        checkpoint: Dict[str, Any],
+        filepath: Path,
+    ) -> Tuple[int, int]:
+        """
+        Save checkpoint with gzip compression.
+
+        Returns:
+            (original_size, compressed_size)
+        """
+        # Serialize to buffer
+        buffer = io.BytesIO()
+        torch.save(checkpoint, buffer)
+        original_data = buffer.getvalue()
+        original_size = len(original_data)
+
+        # Compress
+        compressed_data = gzip.compress(
+            original_data,
+            compresslevel=self.config.compression_level
+        )
+        compressed_size = len(compressed_data)
+
+        # Write to file
+        with open(filepath, 'wb') as f:
+            f.write(compressed_data)
+
+        return original_size, compressed_size
+
+    # ==================== CLEANUP ====================
+
+    def _cleanup_old_snapshots(self):
+        """Remove old hot snapshots beyond max limit."""
+        while len(self.hot_snapshots) > self.config.max_hot_snapshots:
+            oldest = self.hot_snapshots.pop(0)
+            try:
+                Path(oldest.file_path).unlink(missing_ok=True)
+                logger.debug(f"Removed old snapshot: {oldest.checkpoint_id}")
+            except Exception as e:
+                logger.warning(f"Failed to remove snapshot: {e}")
+
+    def _cleanup_old_checkpoints(self):
+        """Remove old cold checkpoints beyond max limit."""
+        while len(self.cold_checkpoints) > self.config.max_cold_checkpoints:
+            oldest = self.cold_checkpoints.pop(0)
+            try:
+                Path(oldest.file_path).unlink(missing_ok=True)
+                logger.debug(f"Removed old checkpoint: {oldest.checkpoint_id}")
+            except Exception as e:
+                logger.warning(f"Failed to remove checkpoint: {e}")
+
+    # ==================== LOADING ====================
+
+    def load_latest_checkpoint(self) -> Optional[Dict[str, Any]]:
+        """Load the most recent checkpoint (hot or cold)."""
+        with self._lock:
+            # Find latest
+            all_checkpoints = self.hot_snapshots + self.cold_checkpoints
+            if not all_checkpoints:
+                return None
+
+            latest = max(all_checkpoints, key=lambda c: c.timestamp)
+            return self.load_checkpoint(latest.file_path)
+
+    def load_checkpoint(self, filepath: str) -> Optional[Dict[str, Any]]:
+        """Load checkpoint from file."""
+        try:
+            filepath = Path(filepath)
+
+            if filepath.suffix == '.gz' or str(filepath).endswith('.pt.gz'):
+                # Compressed
+                with gzip.open(filepath, 'rb') as f:
+                    buffer = io.BytesIO(f.read())
+                return torch.load(buffer, map_location='cpu')
+            else:
+                # Uncompressed
+                return torch.load(filepath, map_location='cpu')
+
+        except Exception as e:
+            logger.error(f"Failed to load checkpoint {filepath}: {e}")
+            return None
+
+    def restore_from_checkpoint(self, checkpoint: Dict[str, Any]) -> bool:
+        """
+        Restore model/optimizer/trainer from checkpoint.
+
+        Args:
+            checkpoint: Checkpoint dictionary
+
+        Returns:
+            True if successful
+        """
+        with self._lock:
+            try:
+                # Restore model
+                if 'model_state_dict' in checkpoint:
+                    self.model.load_state_dict(checkpoint['model_state_dict'])
+
+                # Restore optimizer
+                if 'optimizer_state_dict' in checkpoint:
+                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
+
+                # Restore training state
+                self.training_step = checkpoint.get('training_step', 0)
+                self.outer_step = checkpoint.get('outer_step', 0)
+                self.inner_step = checkpoint.get('inner_step', 0)
+
+                # Restore DiLoCo state
+                if self.diloco is not None and 'diloco_state' in checkpoint:
+                    self.diloco.load_state_dict(checkpoint['diloco_state'])
+
+                self.recoveries_performed += 1
+
+                logger.info(f"Restored from checkpoint: "
+                            f"step={self.training_step}, "
+                            f"outer={self.outer_step}")
+
+                return True
+
+            except Exception as e:
+                logger.error(f"Failed to restore from checkpoint: {e}")
+                return False
+
+    # ==================== PEER RECOVERY ====================
+
+    async def fetch_neighbor_snapshot(
+        self,
+        peer_id: str,
+    ) -> Optional[Dict[str, Any]]:
+        """
+        Fetch hot snapshot from a neighbor peer.
+
+        Used for fast recovery after crash.
+
+        Args:
+            peer_id: ID of peer to fetch from
+
+        Returns:
+            Checkpoint dict if successful, None otherwise
+        """
+        if self.p2p is None:
+            logger.warning("No P2P manager - cannot fetch from peer")
+            return None
+
+        try:
+            # Look up peer's checkpoint in DHT
+            key = f"checkpoint_{peer_id}"
+
+            if hasattr(self.p2p, 'dht') and self.p2p.dht:
+                checkpoint_info = self.p2p.dht.lookup_value(key)
+
+                if checkpoint_info:
+                    # Fetch via gRPC
+                    # This would use a GetHotSnapshot RPC
+                    logger.info(f"Found checkpoint from peer {peer_id}")
+                    # Implementation would go here
+
+            return None
+
+        except Exception as e:
+            logger.error(f"Failed to fetch from peer {peer_id}: {e}")
+            return None
+
+    async def try_auto_recovery(self) -> bool:
+        """
+        Attempt automatic recovery from peers.
+
+        Tries to find and load a recent checkpoint from any available peer.
+
+        Returns:
+            True if recovery succeeded
+        """
+        if not self.config.auto_fetch_on_start:
+            return False
+
+        if self.p2p is None:
+            return False
+
+        # Get list of known peers
+        peers = []
+        if hasattr(self.p2p, 'get_peers'):
+            peers = self.p2p.get_peers()
+
+        # Try each peer
+        for peer_id in peers:
+            checkpoint = await self.fetch_neighbor_snapshot(peer_id)
+            if checkpoint:
+                return self.restore_from_checkpoint(checkpoint)
+
+        return False
+
+    def _announce_checkpoint(self, metadata: CheckpointMetadata):
+        """Announce checkpoint availability to DHT."""
+        if self.p2p is None:
+            return
+
+        try:
+            if hasattr(self.p2p, 'dht') and self.p2p.dht:
+                key = f"checkpoint_{self.node_id}"
+                value = {
+                    'checkpoint_id': metadata.checkpoint_id,
+                    'timestamp': metadata.timestamp,
+                    'training_step': metadata.training_step,
+                    'model_hash': metadata.model_hash,
+                }
+                self.p2p.dht.store(key, str(value))
+                logger.debug(f"Announced checkpoint to DHT: {metadata.checkpoint_id}")
+
+        except Exception as e:
+            logger.warning(f"Failed to announce checkpoint: {e}")
+
+    # ==================== SERVING ====================
+
+    def get_latest_snapshot_for_serving(self) -> Optional[bytes]:
+        """
+        Get latest snapshot data for serving to peers.
+
+        Returns compressed checkpoint bytes.
+        """
+        if not self.config.serve_to_peers:
+            return None
+
+        with self._lock:
+            if not self.hot_snapshots:
+                return None
+
+            latest = self.hot_snapshots[-1]
+
+            try:
+                with open(latest.file_path, 'rb') as f:
+                    data = f.read()
+
+                self.snapshots_served += 1
+                return data
+
+            except Exception as e:
+                logger.error(f"Failed to read snapshot for serving: {e}")
+                return None
+
+    # ==================== UTILITIES ====================
+
+    def _compute_model_hash(self) -> str:
+        """Compute hash of model parameters."""
+        hasher = hashlib.sha256()
+
+        for name, param in sorted(self.model.named_parameters()):
+            hasher.update(name.encode())
+            hasher.update(param.data.cpu().numpy().tobytes()[:1000])  # First 1000 bytes
+
+        return hasher.hexdigest()[:16]
+
+    def _get_layer_ids(self) -> List[int]:
+        """Get layer IDs from model if available."""
+        if hasattr(self.model, 'my_layer_ids'):
+            return list(self.model.my_layer_ids)
+        return []
+
+    def update_training_state(
+        self,
+        training_step: int,
+        outer_step: int = 0,
+        inner_step: int = 0,
+    ):
+        """Update training state for checkpoints."""
+        with self._lock:
+            self.training_step = training_step
+            self.outer_step = outer_step
+            self.inner_step = inner_step

+    def get_stats(self) -> Dict[str, Any]:
+        """Get checkpointer statistics."""
+        with self._lock:
+            return {
+                'running': self.running,
+                'snapshots_saved': self.snapshots_saved,
+                'snapshots_served': self.snapshots_served,
+                'recoveries_performed': self.recoveries_performed,
+                'hot_snapshot_count': len(self.hot_snapshots),
+                'cold_checkpoint_count': len(self.cold_checkpoints),
+                'latest_snapshot_age': (
+                    self.hot_snapshots[-1].age_seconds
+                    if self.hot_snapshots else None
+                ),
+                'training_step': self.training_step,
+            }
+
+    def force_snapshot(self) -> Optional[CheckpointMetadata]:
+        """Force an immediate hot snapshot."""
+        return self._save_hot_snapshot()
+
+
+# ==================== FACTORY FUNCTIONS ====================
+
+def create_checkpointer(
+    model: nn.Module,
+    optimizer: torch.optim.Optimizer,
+    checkpoint_dir: str = "./checkpoints",
+    snapshot_interval: float = 120.0,
+    node_id: str = "",
+    **config_kwargs,
+) -> SpeculativeCheckpointer:
+    """
+    Factory function to create a speculative checkpointer.
+
+    Args:
+        model: Model to checkpoint
+        optimizer: Optimizer to checkpoint
+        checkpoint_dir: Directory for checkpoints
+        snapshot_interval: Seconds between hot snapshots
+        node_id: This node's ID
+        **config_kwargs: Additional config options
+
+    Returns:
+        Configured SpeculativeCheckpointer
+    """
+    config = CheckpointConfig(
+        checkpoint_dir=checkpoint_dir,
+        snapshot_interval=snapshot_interval,
+        **config_kwargs,
+    )
+
+    return SpeculativeCheckpointer(
+        model=model,
+        optimizer=optimizer,
+        config=config,
+        node_id=node_id,
+    )
+
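For orientation, the sketch below shows one way the create_checkpointer factory and the lifecycle methods in this file could be wired into a training loop. It is a minimal illustration rather than code shipped in the package: the import path is inferred from the file layout above (neuroshard/core/swarm/checkpoint.py), and the toy model, optimizer, node ID, and step count are placeholders.

import torch
import torch.nn as nn

from neuroshard.core.swarm.checkpoint import create_checkpointer  # path assumed from the file list

model = nn.Linear(16, 16)  # stand-in for the real sharded model
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)

# Hot snapshots every 2 minutes into ./checkpoints; node_id is a placeholder.
checkpointer = create_checkpointer(
    model=model,
    optimizer=optimizer,
    checkpoint_dir="./checkpoints",
    snapshot_interval=120.0,
    node_id="node-abc123",
)
checkpointer.start()

try:
    for step in range(1000):
        # forward/backward/optimizer.step() for the real training loop goes here
        checkpointer.update_training_state(training_step=step)
finally:
    checkpointer.force_snapshot()  # flush one last hot snapshot before shutdown
    checkpointer.stop()
    print(checkpointer.get_stats())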