nexaroa-0.0.111-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/swarm/checkpoint.py
@@ -0,0 +1,709 @@
"""
Speculative Checkpointing - High-Frequency Snapshots for Fast Recovery

Implements background checkpointing for resilient distributed training:
- Saves snapshots every 2 minutes (configurable)
- Keeps rolling window of last N snapshots
- Announces availability to DHT for peer recovery
- Enables fast crash recovery vs full restart

Key Insight: "Cheaper to over-checkpoint than to re-train."

On crash, neighbors can fetch the "hot" snapshot and resume
with minimal loss of training progress.

Usage:
    checkpointer = SpeculativeCheckpointer(
        model=model,
        optimizer=optimizer,
        diloco_trainer=trainer,
        checkpoint_dir="/path/to/checkpoints"
    )
    checkpointer.start()

    # On crash recovery:
    checkpoint = await checkpointer.fetch_neighbor_snapshot(peer_id)
    checkpointer.restore_from_checkpoint(checkpoint)
"""

import asyncio
import gzip
import hashlib
import io
import logging
import os
import shutil
import threading
import time
from dataclasses import dataclass, field
from pathlib import Path
from typing import Dict, List, Optional, Any, Callable, Tuple
from enum import Enum

import torch
import torch.nn as nn

logger = logging.getLogger(__name__)


class CheckpointType(Enum):
    """Type of checkpoint."""
    HOT = "hot"            # Frequent speculative snapshot
    COLD = "cold"          # Less frequent, more complete
    RECOVERY = "recovery"  # Fetched from peer for recovery


@dataclass
class CheckpointMetadata:
    """Metadata for a checkpoint."""
    checkpoint_id: str
    timestamp: float
    checkpoint_type: CheckpointType

    # Training state
    training_step: int
    outer_step: int
    inner_step: int

    # Model info
    model_hash: str
    num_params: int
    layer_ids: List[int]

    # Storage
    file_path: str
    compressed_size: int
    original_size: int

    # Node info
    node_id: str

    @property
    def age_seconds(self) -> float:
        return time.time() - self.timestamp

    @property
    def compression_ratio(self) -> float:
        if self.original_size == 0:
            return 1.0
        return self.compressed_size / self.original_size


@dataclass
class CheckpointConfig:
    """Configuration for speculative checkpointing."""
    # Timing
    snapshot_interval: float = 120.0          # 2 minutes
    cold_checkpoint_interval: float = 3600.0  # 1 hour

    # Storage
    max_hot_snapshots: int = 5     # Keep last 5 hot snapshots
    max_cold_checkpoints: int = 3  # Keep last 3 cold checkpoints
    checkpoint_dir: str = "./checkpoints"

    # Compression
    compression_level: int = 6  # gzip compression (1-9)

    # Networking
    announce_to_dht: bool = True  # Announce availability
    serve_to_peers: bool = True   # Allow peers to fetch

    # Recovery
    auto_fetch_on_start: bool = True  # Try to fetch from peers on start
    recovery_timeout: float = 60.0    # Timeout for fetching


class SpeculativeCheckpointer:
    """
    Background checkpointer for resilient distributed training.

    Runs in a background thread, periodically saving:
    - Hot snapshots (every 2 minutes) - for fast recovery
    - Cold checkpoints (hourly) - more complete, for long-term storage

    Integrates with:
    - DiLoCoTrainer for training state
    - P2P/DHT for checkpoint announcement and fetching
    - gRPC for serving checkpoints to peers
    """

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        diloco_trainer: Optional[Any] = None,  # DiLoCoTrainer
        config: Optional[CheckpointConfig] = None,
        node_id: str = "",
        p2p_manager: Optional[Any] = None,
    ):
        """
        Initialize speculative checkpointer.

        Args:
            model: Model to checkpoint
            optimizer: Optimizer to checkpoint
            diloco_trainer: Optional DiLoCo trainer for additional state
            config: Checkpoint configuration
            node_id: This node's ID
            p2p_manager: P2P manager for DHT announcements
        """
        self.model = model
        self.optimizer = optimizer
        self.diloco = diloco_trainer
        self.config = config or CheckpointConfig()
        self.node_id = node_id
        self.p2p = p2p_manager

        # Ensure checkpoint directory exists
        self.checkpoint_dir = Path(self.config.checkpoint_dir)
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Snapshot tracking
        self.hot_snapshots: List[CheckpointMetadata] = []
        self.cold_checkpoints: List[CheckpointMetadata] = []

        # Current state
        self.training_step = 0
        self.outer_step = 0
        self.inner_step = 0

        # Background thread
        self.running = False
        self._thread: Optional[threading.Thread] = None
        self._last_hot_snapshot = 0.0
        self._last_cold_checkpoint = 0.0

        # Stats
        self.snapshots_saved = 0
        self.snapshots_served = 0
        self.recoveries_performed = 0

        # Lock for thread safety
        self._lock = threading.RLock()

        logger.info(f"SpeculativeCheckpointer initialized: "
                    f"hot_interval={self.config.snapshot_interval}s, "
                    f"cold_interval={self.config.cold_checkpoint_interval}s")

    # ==================== LIFECYCLE ====================

    def start(self):
        """Start background checkpointing."""
        if self.running:
            return

        self.running = True
        self._thread = threading.Thread(
            target=self._checkpoint_loop,
            daemon=True,
            name="SpeculativeCheckpointer"
        )
        self._thread.start()

        logger.info("Speculative checkpointing started")

    def stop(self):
        """Stop background checkpointing."""
        self.running = False
        if self._thread:
            self._thread.join(timeout=5.0)

        logger.info("Speculative checkpointing stopped")

    def _checkpoint_loop(self):
        """Background loop for periodic checkpointing."""
        while self.running:
            try:
                now = time.time()

                # Hot snapshot check
                if (now - self._last_hot_snapshot) >= self.config.snapshot_interval:
                    self._save_hot_snapshot()
                    self._last_hot_snapshot = now

                # Cold checkpoint check
                if (now - self._last_cold_checkpoint) >= self.config.cold_checkpoint_interval:
                    self._save_cold_checkpoint()
                    self._last_cold_checkpoint = now

                # Sleep for a bit
                time.sleep(10)  # Check every 10 seconds

            except Exception as e:
                logger.error(f"Checkpoint loop error: {e}")
                time.sleep(30)  # Back off on error

    # ==================== SAVING ====================

    def _save_hot_snapshot(self) -> Optional[CheckpointMetadata]:
        """Save a hot snapshot for fast recovery."""
        with self._lock:
            try:
                checkpoint_id = f"hot_{int(time.time())}_{self.node_id[:8]}"
                filename = f"{checkpoint_id}.pt.gz"
                filepath = self.checkpoint_dir / filename

                # Build checkpoint
                checkpoint = self._build_checkpoint(CheckpointType.HOT)

                # Save compressed
                original_size, compressed_size = self._save_compressed(
                    checkpoint, filepath
                )

                # Create metadata
                metadata = CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    timestamp=time.time(),
                    checkpoint_type=CheckpointType.HOT,
                    training_step=self.training_step,
                    outer_step=self.outer_step,
                    inner_step=self.inner_step,
                    model_hash=self._compute_model_hash(),
                    num_params=sum(p.numel() for p in self.model.parameters()),
                    layer_ids=self._get_layer_ids(),
                    file_path=str(filepath),
                    compressed_size=compressed_size,
                    original_size=original_size,
                    node_id=self.node_id,
                )

                # Track snapshot
                self.hot_snapshots.append(metadata)
                self.snapshots_saved += 1

                # Cleanup old snapshots
                self._cleanup_old_snapshots()

                # Announce to DHT
                if self.config.announce_to_dht:
                    self._announce_checkpoint(metadata)

                logger.info(f"Hot snapshot saved: {checkpoint_id} "
                            f"({compressed_size/1024:.1f}KB, "
                            f"ratio={metadata.compression_ratio:.2f})")

                return metadata

            except Exception as e:
                logger.error(f"Failed to save hot snapshot: {e}")
                return None

    def _save_cold_checkpoint(self) -> Optional[CheckpointMetadata]:
        """Save a cold checkpoint with full state."""
        with self._lock:
            try:
                checkpoint_id = f"cold_{int(time.time())}_{self.node_id[:8]}"
                filename = f"{checkpoint_id}.pt.gz"
                filepath = self.checkpoint_dir / filename

                # Build checkpoint (more complete than hot)
                checkpoint = self._build_checkpoint(CheckpointType.COLD)

                # Save compressed
                original_size, compressed_size = self._save_compressed(
                    checkpoint, filepath
                )

                # Create metadata
                metadata = CheckpointMetadata(
                    checkpoint_id=checkpoint_id,
                    timestamp=time.time(),
                    checkpoint_type=CheckpointType.COLD,
                    training_step=self.training_step,
                    outer_step=self.outer_step,
                    inner_step=self.inner_step,
                    model_hash=self._compute_model_hash(),
                    num_params=sum(p.numel() for p in self.model.parameters()),
                    layer_ids=self._get_layer_ids(),
                    file_path=str(filepath),
                    compressed_size=compressed_size,
                    original_size=original_size,
                    node_id=self.node_id,
                )

                # Track checkpoint
                self.cold_checkpoints.append(metadata)

                # Cleanup old checkpoints
                self._cleanup_old_checkpoints()

                logger.info(f"Cold checkpoint saved: {checkpoint_id} "
                            f"({compressed_size/1024/1024:.1f}MB)")

                return metadata

            except Exception as e:
                logger.error(f"Failed to save cold checkpoint: {e}")
                return None

    def _build_checkpoint(self, checkpoint_type: CheckpointType) -> Dict[str, Any]:
        """Build checkpoint dictionary."""
        checkpoint = {
            'checkpoint_type': checkpoint_type.value,
            'timestamp': time.time(),
            'node_id': self.node_id,

            # Model state
            'model_state_dict': self.model.state_dict(),

            # Optimizer state
            'optimizer_state_dict': self.optimizer.state_dict(),

            # Training progress
            'training_step': self.training_step,
            'outer_step': self.outer_step,
            'inner_step': self.inner_step,
        }

        # Add DiLoCo state if available
        if self.diloco is not None:
            checkpoint['diloco_state'] = {
                'initial_weights': {
                    k: v.clone() for k, v in self.diloco.initial_weights.items()
                },
                'outer_optimizer': self.diloco.outer_optimizer.state_dict(),
                'stats': {
                    'inner_step_count': self.diloco.stats.inner_step_count,
                    'outer_step_count': self.diloco.stats.outer_step_count,
                    'total_inner_steps': self.diloco.stats.total_inner_steps,
                },
                'phase': self.diloco.phase.value,
            }

        # For cold checkpoints, add extra info
        if checkpoint_type == CheckpointType.COLD:
            checkpoint['model_hash'] = self._compute_model_hash()
            checkpoint['layer_ids'] = self._get_layer_ids()

        return checkpoint

    def _save_compressed(
        self,
        checkpoint: Dict[str, Any],
        filepath: Path,
    ) -> Tuple[int, int]:
        """
        Save checkpoint with gzip compression.

        Returns:
            (original_size, compressed_size)
        """
        # Serialize to buffer
        buffer = io.BytesIO()
        torch.save(checkpoint, buffer)
        original_data = buffer.getvalue()
        original_size = len(original_data)

        # Compress
        compressed_data = gzip.compress(
            original_data,
            compresslevel=self.config.compression_level
        )
        compressed_size = len(compressed_data)

        # Write to file
        with open(filepath, 'wb') as f:
            f.write(compressed_data)

        return original_size, compressed_size

    # ==================== CLEANUP ====================

    def _cleanup_old_snapshots(self):
        """Remove old hot snapshots beyond max limit."""
        while len(self.hot_snapshots) > self.config.max_hot_snapshots:
            oldest = self.hot_snapshots.pop(0)
            try:
                Path(oldest.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old snapshot: {oldest.checkpoint_id}")
            except Exception as e:
                logger.warning(f"Failed to remove snapshot: {e}")

    def _cleanup_old_checkpoints(self):
        """Remove old cold checkpoints beyond max limit."""
        while len(self.cold_checkpoints) > self.config.max_cold_checkpoints:
            oldest = self.cold_checkpoints.pop(0)
            try:
                Path(oldest.file_path).unlink(missing_ok=True)
                logger.debug(f"Removed old checkpoint: {oldest.checkpoint_id}")
            except Exception as e:
                logger.warning(f"Failed to remove checkpoint: {e}")

    # ==================== LOADING ====================

    def load_latest_checkpoint(self) -> Optional[Dict[str, Any]]:
        """Load the most recent checkpoint (hot or cold)."""
        with self._lock:
            # Find latest
            all_checkpoints = self.hot_snapshots + self.cold_checkpoints
            if not all_checkpoints:
                return None

            latest = max(all_checkpoints, key=lambda c: c.timestamp)
            return self.load_checkpoint(latest.file_path)

    def load_checkpoint(self, filepath: str) -> Optional[Dict[str, Any]]:
        """Load checkpoint from file."""
        try:
            filepath = Path(filepath)

            if filepath.suffix == '.gz' or str(filepath).endswith('.pt.gz'):
                # Compressed
                with gzip.open(filepath, 'rb') as f:
                    buffer = io.BytesIO(f.read())
                    return torch.load(buffer, map_location='cpu')
            else:
                # Uncompressed
                return torch.load(filepath, map_location='cpu')

        except Exception as e:
            logger.error(f"Failed to load checkpoint {filepath}: {e}")
            return None

    def restore_from_checkpoint(self, checkpoint: Dict[str, Any]) -> bool:
        """
        Restore model/optimizer/trainer from checkpoint.

        Args:
            checkpoint: Checkpoint dictionary

        Returns:
            True if successful
        """
        with self._lock:
            try:
                # Restore model
                if 'model_state_dict' in checkpoint:
                    self.model.load_state_dict(checkpoint['model_state_dict'])

                # Restore optimizer
                if 'optimizer_state_dict' in checkpoint:
                    self.optimizer.load_state_dict(checkpoint['optimizer_state_dict'])

                # Restore training state
                self.training_step = checkpoint.get('training_step', 0)
                self.outer_step = checkpoint.get('outer_step', 0)
                self.inner_step = checkpoint.get('inner_step', 0)

                # Restore DiLoCo state
                if self.diloco is not None and 'diloco_state' in checkpoint:
                    self.diloco.load_state_dict(checkpoint['diloco_state'])

                self.recoveries_performed += 1

                logger.info(f"Restored from checkpoint: "
                            f"step={self.training_step}, "
                            f"outer={self.outer_step}")

                return True

            except Exception as e:
                logger.error(f"Failed to restore from checkpoint: {e}")
                return False

    # ==================== PEER RECOVERY ====================

    async def fetch_neighbor_snapshot(
        self,
        peer_id: str,
    ) -> Optional[Dict[str, Any]]:
        """
        Fetch hot snapshot from a neighbor peer.

        Used for fast recovery after crash.

        Args:
            peer_id: ID of peer to fetch from

        Returns:
            Checkpoint dict if successful, None otherwise
        """
        if self.p2p is None:
            logger.warning("No P2P manager - cannot fetch from peer")
            return None

        try:
            # Look up peer's checkpoint in DHT
            key = f"checkpoint_{peer_id}"

            if hasattr(self.p2p, 'dht') and self.p2p.dht:
                checkpoint_info = self.p2p.dht.lookup_value(key)

                if checkpoint_info:
                    # Fetch via gRPC
                    # This would use a GetHotSnapshot RPC
                    logger.info(f"Found checkpoint from peer {peer_id}")
                    # Implementation would go here

            return None

        except Exception as e:
            logger.error(f"Failed to fetch from peer {peer_id}: {e}")
            return None

    async def try_auto_recovery(self) -> bool:
        """
        Attempt automatic recovery from peers.

        Tries to find and load a recent checkpoint from any available peer.

        Returns:
            True if recovery succeeded
        """
        if not self.config.auto_fetch_on_start:
            return False

        if self.p2p is None:
            return False

        # Get list of known peers
        peers = []
        if hasattr(self.p2p, 'get_peers'):
            peers = self.p2p.get_peers()

        # Try each peer
        for peer_id in peers:
            checkpoint = await self.fetch_neighbor_snapshot(peer_id)
            if checkpoint:
                return self.restore_from_checkpoint(checkpoint)

        return False

    def _announce_checkpoint(self, metadata: CheckpointMetadata):
        """Announce checkpoint availability to DHT."""
        if self.p2p is None:
            return

        try:
            if hasattr(self.p2p, 'dht') and self.p2p.dht:
                key = f"checkpoint_{self.node_id}"
                value = {
                    'checkpoint_id': metadata.checkpoint_id,
                    'timestamp': metadata.timestamp,
                    'training_step': metadata.training_step,
                    'model_hash': metadata.model_hash,
                }
                self.p2p.dht.store(key, str(value))
                logger.debug(f"Announced checkpoint to DHT: {metadata.checkpoint_id}")

        except Exception as e:
            logger.warning(f"Failed to announce checkpoint: {e}")

    # ==================== SERVING ====================

    def get_latest_snapshot_for_serving(self) -> Optional[bytes]:
        """
        Get latest snapshot data for serving to peers.

        Returns compressed checkpoint bytes.
        """
        if not self.config.serve_to_peers:
            return None

        with self._lock:
            if not self.hot_snapshots:
                return None

            latest = self.hot_snapshots[-1]

            try:
                with open(latest.file_path, 'rb') as f:
                    data = f.read()

                self.snapshots_served += 1
                return data

            except Exception as e:
                logger.error(f"Failed to read snapshot for serving: {e}")
                return None

    # ==================== UTILITIES ====================

    def _compute_model_hash(self) -> str:
        """Compute hash of model parameters."""
        hasher = hashlib.sha256()

        for name, param in sorted(self.model.named_parameters()):
            hasher.update(name.encode())
            hasher.update(param.data.cpu().numpy().tobytes()[:1000])  # First 1000 bytes

        return hasher.hexdigest()[:16]

    def _get_layer_ids(self) -> List[int]:
        """Get layer IDs from model if available."""
        if hasattr(self.model, 'my_layer_ids'):
            return list(self.model.my_layer_ids)
        return []

    def update_training_state(
        self,
        training_step: int,
        outer_step: int = 0,
        inner_step: int = 0,
    ):
        """Update training state for checkpoints."""
        with self._lock:
            self.training_step = training_step
            self.outer_step = outer_step
            self.inner_step = inner_step

    def get_stats(self) -> Dict[str, Any]:
        """Get checkpointer statistics."""
        with self._lock:
            return {
                'running': self.running,
                'snapshots_saved': self.snapshots_saved,
                'snapshots_served': self.snapshots_served,
                'recoveries_performed': self.recoveries_performed,
                'hot_snapshot_count': len(self.hot_snapshots),
                'cold_checkpoint_count': len(self.cold_checkpoints),
                'latest_snapshot_age': (
                    self.hot_snapshots[-1].age_seconds
                    if self.hot_snapshots else None
                ),
                'training_step': self.training_step,
            }

    def force_snapshot(self) -> Optional[CheckpointMetadata]:
        """Force an immediate hot snapshot."""
        return self._save_hot_snapshot()


# ==================== FACTORY FUNCTIONS ====================

def create_checkpointer(
    model: nn.Module,
    optimizer: torch.optim.Optimizer,
    checkpoint_dir: str = "./checkpoints",
    snapshot_interval: float = 120.0,
    node_id: str = "",
    **config_kwargs,
) -> SpeculativeCheckpointer:
    """
    Factory function to create a speculative checkpointer.

    Args:
        model: Model to checkpoint
        optimizer: Optimizer to checkpoint
        checkpoint_dir: Directory for checkpoints
        snapshot_interval: Seconds between hot snapshots
        node_id: This node's ID
        **config_kwargs: Additional config options

    Returns:
        Configured SpeculativeCheckpointer
    """
    config = CheckpointConfig(
        checkpoint_dir=checkpoint_dir,
        snapshot_interval=snapshot_interval,
        **config_kwargs,
    )

    return SpeculativeCheckpointer(
        model=model,
        optimizer=optimizer,
        config=config,
        node_id=node_id,
    )
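For orientation, below is a minimal usage sketch of the save/restore flow exposed by this module. It is not part of the package diff: it assumes the wheel is installed so that neuroshard.core.swarm.checkpoint is importable (matching the file list above), and the toy model, data, node ID, and short snapshot interval are illustrative stand-ins only; a real deployment would also pass a DiLoCo trainer and P2P manager for DHT announcement and peer recovery.

    # Illustrative sketch only; names outside the module (toy model, node ID) are made up.
    import torch
    import torch.nn as nn

    from neuroshard.core.swarm.checkpoint import create_checkpointer

    model = nn.Linear(16, 4)  # stand-in for the real sharded model
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

    checkpointer = create_checkpointer(
        model=model,
        optimizer=optimizer,
        checkpoint_dir="./checkpoints",
        snapshot_interval=5.0,   # short interval for demonstration only
        node_id="demo-node-0001",
    )
    checkpointer.start()

    # Training loop: report progress so snapshots carry current step counters.
    for step in range(100):
        optimizer.zero_grad()
        loss = model(torch.randn(8, 16)).pow(2).mean()
        loss.backward()
        optimizer.step()
        checkpointer.update_training_state(training_step=step)

    # Force a final hot snapshot, then restore from the newest on-disk checkpoint.
    checkpointer.force_snapshot()
    checkpointer.stop()

    checkpoint = checkpointer.load_latest_checkpoint()
    if checkpoint is not None:
        checkpointer.restore_from_checkpoint(checkpoint)
    print(checkpointer.get_stats())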