nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,1602 @@
"""
Distributed Training System for NeuroLLM

Implements decentralized training where:
1. Nodes contribute compute for forward/backward passes
2. Gradients are aggregated via gossip protocol
3. Training rewards are distributed in NEURO tokens
4. Model checkpoints are shared across the network

Key Components:
- GradientCompressor: Sparsifies, quantizes, and compresses gradients for transport
- GradientAggregator: Collects and averages gradients from peers
- TrainingCoordinator: Orchestrates distributed training and computes NEURO rewards
- GenesisDataLoader: Streams verified Genesis Dataset shards to Layer 0 nodes
- FederatedDataManager: Handles federated dataset management

Training Flow:
1. Coordinator broadcasts current model state hash
2. Nodes with a matching state participate in the training round
3. Each node processes a local data batch
4. Gradients are compressed and gossiped
5. Aggregated gradients are applied
6. A new checkpoint is created and distributed
7. NEURO rewards are calculated and distributed
"""

import torch
import torch.nn as nn
import threading
import time
import hashlib
from concurrent.futures import ThreadPoolExecutor
import logging
import json
import io
import zlib
import base64
import os
import requests
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict

# Import economics constants for consistency
from neuroshard.core.economics.constants import (
    TRAINING_REWARD_PER_BATCH,
    DATA_REWARD_PER_SAMPLE
)

logger = logging.getLogger(__name__)


class TrainingState(Enum):
    """State of the training coordinator."""
    IDLE = "idle"
    COLLECTING = "collecting"        # Collecting gradients from peers
    AGGREGATING = "aggregating"      # Aggregating gradients
    APPLYING = "applying"            # Applying updates
    CHECKPOINTING = "checkpointing"  # Creating checkpoint


@dataclass
class GradientContribution:
    """A gradient contribution from a node."""
    node_id: str
    round_id: int
    layer_gradients: Dict[str, bytes]  # layer_name -> compressed gradient
    batch_size: int
    loss: float
    timestamp: float
    signature: str  # Proof of work


@dataclass
class TrainingRound:
    """A single training round."""
    round_id: int
    started_at: float
    model_hash: str

    # Contributions
    contributions: Dict[str, GradientContribution] = field(default_factory=dict)
    min_contributions: int = 3
    max_contributions: int = 100

    # Results
    aggregated_gradients: Optional[Dict[str, torch.Tensor]] = None
    total_batch_size: int = 0
    avg_loss: float = 0.0

    # State
    completed: bool = False
    applied: bool = False


@dataclass
class TrainingReward:
    """Reward for training contribution."""
    node_id: str
    round_id: int
    compute_reward: float  # For compute contribution
    data_reward: float     # For data contribution
    quality_bonus: float   # For high-quality gradients
    total_neuro: float


class GradientCompressor:
    """
    Compresses gradients for efficient network transmission.

    Uses a combination of:
    1. Top-K sparsification (keep only largest gradients)
    2. Quantization (reduce precision)
    3. zlib compression
    """

    def __init__(self, top_k_ratio: float = 0.1, bits: int = 8):
        self.top_k_ratio = top_k_ratio
        self.bits = bits

    def compress(self, gradient: torch.Tensor) -> bytes:
        """Compress a gradient tensor."""
        # CRITICAL: Move to CPU first for MPS/CUDA compatibility
        gradient = gradient.detach().cpu()

        # Flatten
        flat = gradient.flatten()

        # Top-K sparsification
        k = max(1, int(len(flat) * self.top_k_ratio))
        values, indices = torch.topk(flat.abs(), k)

        # Get actual values (with signs)
        sparse_values = flat[indices]

        # Quantize to specified bits
        max_val = sparse_values.abs().max()
        if max_val > 0:
            scale = (2 ** (self.bits - 1) - 1) / max_val
            quantized = (sparse_values * scale).round().to(torch.int8)
        else:
            quantized = torch.zeros(k, dtype=torch.int8)
            scale = 1.0

        # Pack into bytes (tensors already on CPU)
        data = {
            "shape": list(gradient.shape),
            "k": k,
            "indices": base64.b64encode(indices.numpy().tobytes()).decode('ascii'),
            "values": base64.b64encode(quantized.numpy().tobytes()).decode('ascii'),
            "scale": float(scale),
            "dtype": str(gradient.dtype),
        }

        # Serialize and compress
        json_data = json.dumps(data).encode()
        return zlib.compress(json_data)

    def decompress(self, data: bytes, device: str = "cpu") -> torch.Tensor:
        """Decompress a gradient tensor."""
        # Decompress and deserialize
        json_data = zlib.decompress(data)
        packed = json.loads(json_data)

        # Unpack
        shape = packed["shape"]
        k = packed["k"]
        indices = torch.frombuffer(
            bytearray(base64.b64decode(packed["indices"])),
            dtype=torch.int64
        ).clone().to(device)
        values = torch.frombuffer(
            bytearray(base64.b64decode(packed["values"])),
            dtype=torch.int8
        ).float().clone().to(device)
        scale = packed["scale"]

        # Dequantize
        values = values / scale

        # Reconstruct sparse tensor (cast the element count to int for torch.zeros)
        flat = torch.zeros(int(torch.prod(torch.tensor(shape))), device=device)
        flat[indices] = values

        return flat.view(*shape)

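A minimal round-trip sketch of the compressor (illustrative, not part of the packaged file). The output is a lossy sparse approximation: only the top-k magnitudes survive, quantized to int8:

    import torch
    from neuroshard.core.training.distributed import GradientCompressor

    compressor = GradientCompressor(top_k_ratio=0.1, bits=8)
    grad = torch.randn(64, 64)

    blob = compressor.compress(grad)      # zlib-compressed JSON payload
    approx = compressor.decompress(blob)  # sparse, dequantized reconstruction

    assert approx.shape == grad.shape
    print(f"{len(blob)} bytes, nonzero ratio ~{(approx != 0).float().mean():.2f}")  # ~0.10
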
class GradientAggregator:
    """
    Aggregates gradients from multiple nodes.

    Supports:
    - Simple averaging
    - Weighted averaging (by batch size)
    - Robust aggregation (median, trimmed mean)
    """

    def __init__(self, method: str = "weighted_mean"):
        self.method = method
        self.compressor = GradientCompressor()

    def aggregate(
        self,
        contributions: List[GradientContribution],
        layer_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """
        Aggregate gradients from multiple contributions.

        Args:
            contributions: List of gradient contributions
            layer_names: Names of layers to aggregate

        Returns:
            Aggregated gradients per layer
        """
        if not contributions:
            return {}

        aggregated = {}
        total_batch_size = sum(c.batch_size for c in contributions)

        for layer_name in layer_names:
            # Collect gradients for this layer
            gradients = []
            weights = []

            for contrib in contributions:
                if layer_name in contrib.layer_gradients:
                    grad = self.compressor.decompress(contrib.layer_gradients[layer_name])
                    gradients.append(grad)
                    weights.append(contrib.batch_size)

            if not gradients:
                continue

            # Stack gradients
            stacked = torch.stack(gradients)

            if self.method == "mean":
                aggregated[layer_name] = stacked.mean(dim=0)

            elif self.method == "weighted_mean":
                weights_tensor = torch.tensor(weights, dtype=torch.float32)
                weights_tensor = weights_tensor / weights_tensor.sum()
                aggregated[layer_name] = (stacked * weights_tensor.view(-1, *([1] * (stacked.dim() - 1)))).sum(dim=0)

            elif self.method == "median":
                aggregated[layer_name] = stacked.median(dim=0)[0]

            elif self.method == "trimmed_mean":
                # Remove top and bottom 10%
                k = max(1, len(gradients) // 10)
                sorted_grads = stacked.sort(dim=0)[0]
                aggregated[layer_name] = sorted_grads[k:-k].mean(dim=0) if k < len(gradients) // 2 else stacked.mean(dim=0)

        return aggregated

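A small sketch of weighted-mean aggregation over two fabricated contributions (node IDs and values are made up; top_k_ratio=1.0 keeps every entry so the arithmetic is exact). With batch sizes 8 and 2 the weights are 0.8 and 0.2, so all-ones and all-minus-ones average to 0.6:

    import time
    import torch
    from neuroshard.core.training.distributed import (
        GradientAggregator, GradientCompressor, GradientContribution,
    )

    comp = GradientCompressor(top_k_ratio=1.0)

    def contrib(node_id, grad, batch_size, loss):
        return GradientContribution(
            node_id=node_id, round_id=1,
            layer_gradients={"w": comp.compress(grad)},
            batch_size=batch_size, loss=loss,
            timestamp=time.time(), signature="",
        )

    agg = GradientAggregator(method="weighted_mean")
    out = agg.aggregate(
        [contrib("node_a", torch.ones(4, 4), 8, 2.0),
         contrib("node_b", -torch.ones(4, 4), 2, 2.2)],
        layer_names=["w"],
    )
    print(out["w"])  # ~0.6 everywhere
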
class TrainingCoordinator:
    """
    Coordinates distributed training across the network.

    Responsibilities:
    1. Initiate training rounds
    2. Collect gradient contributions
    3. Aggregate and apply updates
    4. Distribute rewards
    5. Manage checkpoints

    NOTE: This class is LEGACY and not currently used in production.
    The active reward path uses economics.py constants via ledger.py
    """

    # Configuration
    ROUND_DURATION_SECONDS = 60
    MIN_CONTRIBUTIONS = 3
    GRADIENT_CLIP_NORM = 1.0

    # Reward rates (using economics.py constants for consistency)
    # NOTE: LEGACY - These are kept for backwards compatibility but not actively used
    # Import at class level to match economics.py values
    from neuroshard.core.economics.constants import TRAINING_REWARD_PER_BATCH as COMPUTE_REWARD_PER_BATCH
    from neuroshard.core.economics.constants import DATA_REWARD_PER_SAMPLE
    QUALITY_BONUS_MULTIPLIER = 1.5

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        node_id: str,
        ledger_manager=None,
        on_round_complete: Optional[Callable] = None
    ):
        self.model = model
        self.optimizer = optimizer
        self.node_id = node_id
        self.ledger = ledger_manager
        self.on_round_complete = on_round_complete

        # State
        self.state = TrainingState.IDLE
        self.current_round: Optional[TrainingRound] = None
        self.round_history: List[TrainingRound] = []
        self.global_step = 0

        # Components
        self.aggregator = GradientAggregator()
        self.compressor = GradientCompressor()

        # Threading
        self.lock = threading.Lock()
        self.running = False

        # Stats
        self.total_rounds = 0
        self.total_contributions = 0
        self.total_neuro_distributed = 0.0

        logger.info(f"TrainingCoordinator initialized for node {node_id}")

    def start(self):
        """Start the training coordinator."""
        self.running = True
        threading.Thread(target=self._training_loop, daemon=True).start()
        logger.info("Training coordinator started")

    def stop(self):
        """Stop the training coordinator."""
        self.running = False

    def _training_loop(self):
        """Main training loop."""
        while self.running:
            try:
                if self.state == TrainingState.IDLE:
                    # Start new round
                    self._start_round()

                elif self.state == TrainingState.COLLECTING:
                    # Check if round should complete
                    if self._should_complete_round():
                        self._complete_round()

                time.sleep(1)

            except Exception as e:
                logger.error(f"Training loop error: {e}")
                time.sleep(5)

    def _get_model_hash(self) -> str:
        """Get hash of current model state."""
        state_dict = self.model.state_dict()
        hasher = hashlib.sha256()

        for name, param in sorted(state_dict.items()):
            hasher.update(name.encode())
            hasher.update(param.cpu().numpy().tobytes()[:1000])  # Sample for speed

        return hasher.hexdigest()[:16]

    def _start_round(self):
        """Start a new training round."""
        with self.lock:
            self.total_rounds += 1

            self.current_round = TrainingRound(
                round_id=self.total_rounds,
                started_at=time.time(),
                model_hash=self._get_model_hash(),
                min_contributions=self.MIN_CONTRIBUTIONS,
            )

            self.state = TrainingState.COLLECTING

        logger.info(f"Started training round {self.total_rounds}")

    def _should_complete_round(self) -> bool:
        """Check if current round should complete."""
        if not self.current_round:
            return False

        # Time limit
        elapsed = time.time() - self.current_round.started_at
        if elapsed >= self.ROUND_DURATION_SECONDS:
            return True

        # Max contributions
        if len(self.current_round.contributions) >= self.current_round.max_contributions:
            return True

        return False

    def _complete_round(self):
        """Complete the current training round."""
        if not self.current_round:
            return

        with self.lock:
            round_data = self.current_round

            # Check minimum contributions
            if len(round_data.contributions) < round_data.min_contributions:
                logger.warning(f"Round {round_data.round_id} failed: insufficient contributions "
                               f"({len(round_data.contributions)}/{round_data.min_contributions})")
                self.state = TrainingState.IDLE
                self.current_round = None
                return

            self.state = TrainingState.AGGREGATING

        logger.info(f"Completing round {round_data.round_id} with {len(round_data.contributions)} contributions")

        # Aggregate gradients
        layer_names = [name for name, _ in self.model.named_parameters()]
        aggregated = self.aggregator.aggregate(
            list(round_data.contributions.values()),
            layer_names
        )

        round_data.aggregated_gradients = aggregated
        round_data.total_batch_size = sum(c.batch_size for c in round_data.contributions.values())
        round_data.avg_loss = sum(c.loss for c in round_data.contributions.values()) / len(round_data.contributions)

        # Apply gradients
        with self.lock:
            self.state = TrainingState.APPLYING

        self._apply_gradients(aggregated)
        round_data.applied = True

        # Calculate and distribute rewards
        rewards = self._calculate_rewards(round_data)
        self._distribute_rewards(rewards)

        # Checkpoint
        with self.lock:
            self.state = TrainingState.CHECKPOINTING

        self._create_checkpoint(round_data)

        # Complete
        round_data.completed = True
        self.round_history.append(round_data)

        # Keep only last 100 rounds
        if len(self.round_history) > 100:
            self.round_history = self.round_history[-100:]

        # Callback
        if self.on_round_complete:
            self.on_round_complete(round_data)

        # Reset
        with self.lock:
            self.current_round = None
            self.state = TrainingState.IDLE
            self.global_step += 1

        logger.info(f"Round {round_data.round_id} complete: loss={round_data.avg_loss:.4f}, "
                    f"batch_size={round_data.total_batch_size}")

    def _apply_gradients(self, gradients: Dict[str, torch.Tensor]):
        """Apply aggregated gradients to model."""
        self.optimizer.zero_grad()

        for name, param in self.model.named_parameters():
            if name in gradients:
                if param.grad is None:
                    param.grad = gradients[name].to(param.device)
                else:
                    param.grad.copy_(gradients[name])

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.GRADIENT_CLIP_NORM)

        # Apply
        self.optimizer.step()

    def _calculate_rewards(self, round_data: TrainingRound) -> List[TrainingReward]:
        """Calculate NEURO rewards for contributions."""
        rewards = []

        # Calculate average loss for quality comparison
        avg_loss = round_data.avg_loss

        for node_id, contrib in round_data.contributions.items():
            # Base compute reward
            compute_reward = contrib.batch_size * self.COMPUTE_REWARD_PER_BATCH

            # Data contribution reward
            data_reward = contrib.batch_size * self.DATA_REWARD_PER_SAMPLE

            # Quality bonus (lower loss = better)
            quality_bonus = 0.0
            if contrib.loss < avg_loss:
                quality_bonus = compute_reward * (self.QUALITY_BONUS_MULTIPLIER - 1)

            total = compute_reward + data_reward + quality_bonus

            rewards.append(TrainingReward(
                node_id=node_id,
                round_id=round_data.round_id,
                compute_reward=compute_reward,
                data_reward=data_reward,
                quality_bonus=quality_bonus,
                total_neuro=total
            ))

            self.total_neuro_distributed += total

        return rewards

    def _distribute_rewards(self, rewards: List[TrainingReward]):
        """Distribute NEURO rewards to contributors using PoNW proofs."""
        if not self.ledger:
            logger.debug("No ledger available for reward distribution")
            return

        # Hoisted out of the loop; time is already imported at module level
        from neuroshard.core.economics.ledger import PoNWProof, ProofType

        for reward in rewards:
            try:
                # Create a proper training PoNW proof
                proof = PoNWProof(
                    node_id=reward.node_id,
                    proof_type=ProofType.TRAINING.value,
                    timestamp=time.time(),
                    nonce=f"train_{reward.round_id}_{reward.node_id[:16]}",
                    training_batches=int(reward.compute_reward / self.COMPUTE_REWARD_PER_BATCH),
                    data_samples=int(reward.data_reward / self.DATA_REWARD_PER_SAMPLE),
                    signature=f"training_reward_{reward.round_id}_{reward.node_id}"
                )

                # Process through the ledger (handles deduplication, stats, etc.)
                success, amount, msg = self.ledger.process_proof(proof)

                if success:
                    logger.info(f"Reward: {reward.node_id[:8]}... earned {amount:.6f} NEURO "
                                f"(compute={reward.compute_reward:.6f}, data={reward.data_reward:.6f}, "
                                f"quality={reward.quality_bonus:.6f})")
                else:
                    logger.debug(f"Training reward not processed: {msg}")

            except Exception as e:
                logger.error(f"Failed to distribute reward to {reward.node_id}: {e}")

    def _create_checkpoint(self, round_data: TrainingRound):
        """Create a checkpoint after training round."""
        checkpoint_path = f"checkpoints/neuro_llm_round_{round_data.round_id}.pt"

        try:
            os.makedirs("checkpoints", exist_ok=True)

            torch.save({
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "round_id": round_data.round_id,
                "global_step": self.global_step,
                "model_hash": self._get_model_hash(),
                "avg_loss": round_data.avg_loss,
                "total_batch_size": round_data.total_batch_size,
                "timestamp": time.time(),
            }, checkpoint_path)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

    def submit_contribution(self, contribution: GradientContribution) -> bool:
        """
        Submit a gradient contribution for the current round.

        Called by peers when they have computed gradients.
        """
        with self.lock:
            if self.state != TrainingState.COLLECTING:
                return False

            if not self.current_round:
                return False

            # Verify model hash matches
            # In production, this would be more sophisticated

            # Add contribution
            self.current_round.contributions[contribution.node_id] = contribution
            self.total_contributions += 1

            logger.debug(f"Received contribution from {contribution.node_id[:8]}... "
                         f"(batch_size={contribution.batch_size}, loss={contribution.loss:.4f})")

            return True

    def compute_local_gradients(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor
    ) -> GradientContribution:
        """
        Compute gradients on local data.

        Call this to participate in training.
        """
        self.model.train()

        # Forward pass
        outputs = self.model(input_ids=input_ids, labels=labels)
        loss = outputs["loss"]

        # Backward pass
        loss.backward()

        # Collect and compress gradients
        layer_gradients = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                layer_gradients[name] = self.compressor.compress(param.grad)

        # Clear gradients (they're saved in contribution)
        self.optimizer.zero_grad()

        # Create contribution
        contribution = GradientContribution(
            node_id=self.node_id,
            round_id=self.current_round.round_id if self.current_round else 0,
            layer_gradients=layer_gradients,
            batch_size=input_ids.shape[0],
            loss=loss.item(),
            timestamp=time.time(),
            signature=self._sign_contribution(layer_gradients)
        )

        return contribution

    def _sign_contribution(self, gradients: Dict[str, bytes]) -> str:
        """Sign a contribution for verification."""
        hasher = hashlib.sha256()
        hasher.update(self.node_id.encode())
        hasher.update(str(time.time()).encode())
        for name, data in sorted(gradients.items()):
            hasher.update(name.encode())
            hasher.update(data[:100])  # Sample for speed
        return hasher.hexdigest()

    def get_status(self) -> Dict[str, Any]:
        """Get coordinator status."""
        return {
            "state": self.state.value,
            "global_step": self.global_step,
            "total_rounds": self.total_rounds,
            "total_contributions": self.total_contributions,
            "total_neuro_distributed": self.total_neuro_distributed,
            "current_round": {
                "round_id": self.current_round.round_id,
                "contributions": len(self.current_round.contributions),
                "elapsed": time.time() - self.current_round.started_at,
                "model_hash": self.current_round.model_hash,
            } if self.current_round else None,
            "recent_rounds": [
                {
                    "round_id": r.round_id,
                    "contributions": len(r.contributions),
                    "avg_loss": r.avg_loss,
                    "total_batch_size": r.total_batch_size,
                }
                for r in self.round_history[-10:]
            ]
        }

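A minimal single-node sketch of the round lifecycle (illustrative; no ledger is passed, so reward distribution is skipped). It assumes a toy model whose forward returns a dict with a "loss" key, which is what compute_local_gradients expects; _start_round is called directly here, while start() would normally drive it from the background loop:

    import torch
    import torch.nn as nn
    from neuroshard.core.training.distributed import TrainingCoordinator

    class ToyLM(nn.Module):
        def __init__(self):
            super().__init__()
            self.proj = nn.Linear(8, 8)

        def forward(self, input_ids, labels):
            logits = self.proj(input_ids.float())
            return {"loss": ((logits - labels.float()) ** 2).mean()}

    model = ToyLM()
    coord = TrainingCoordinator(
        model=model,
        optimizer=torch.optim.SGD(model.parameters(), lr=0.01),
        node_id="local-test",
    )
    coord._start_round()  # state -> COLLECTING
    contribution = coord.compute_local_gradients(
        input_ids=torch.randint(0, 2, (4, 8)),
        labels=torch.randint(0, 2, (4, 8)),
    )
    coord.submit_contribution(contribution)
    print(coord.get_status()["current_round"])  # one contribution collected
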
class GenesisDataLoader:
    """
    Loads training data from the verified Genesis Dataset.

    Features:
    - Dynamic shard count (reads from manifest)
    - User-configurable storage limit (max_storage_mb)
    - Shard rotation (cycles through dataset over time)
    - Multi-shard support (downloads multiple shards up to storage limit)
    - ASYNC PREFETCHING: Pre-downloads next shard while training on current

    Active only for nodes holding Layer 0 (Embedding Layer).

    Data Source: CloudFront CDN (backed by S3)
    """
    # CloudFront CDN URL - single source of truth (cached, DDoS protected)
    GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
    # Size per shard in MB (must match populate_genesis_s3.py)
    SHARD_SIZE_MB = 10

    def __init__(
        self,
        node_id: str,
        tokenizer,
        cache_dir: str = None,  # Default to ~/.neuroshard/data_cache
        max_storage_mb: float = 100.0,  # User-configurable limit
        manifest_version: int = 1
    ):
        self.node_id = node_id
        self.tokenizer = tokenizer

        # Default cache_dir to ~/.neuroshard/data_cache for consistent storage
        if cache_dir is None:
            cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")
        self.cache_dir = cache_dir
        self.max_storage_mb = max_storage_mb
        self.manifest_version = manifest_version

        # CloudFront CDN manifest URL - single source of truth
        self.manifest_url = f"{self.GENESIS_CDN_URL}/manifest.json"

        # Manifest data (cached, refreshed periodically)
        self.manifest = None
        self.total_shards = 0
        self.manifest_last_fetch = 0
        self.MANIFEST_REFRESH_INTERVAL = 600  # Refresh manifest every 10 minutes (auto-update tokenizer)

        # Shard management
        self.max_shards = max(1, int(max_storage_mb / self.SHARD_SIZE_MB))
        self.assigned_shard_ids = []  # List of shard IDs this node is responsible for
        self.loaded_shards = {}  # shard_id -> tensor data
        self.current_shard_idx = 0  # Index into assigned_shard_ids for rotation
        self.shard_rotation_count = 0  # How many times we've rotated through
        self.loading_shards = set()  # Track shards currently being downloaded
        self._shard_lock = threading.Lock()  # Lock for shard loading
        self._download_executor = ThreadPoolExecutor(max_workers=3, thread_name_prefix="shard-download")

        # ASYNC PREFETCHING: Keep next shard(s) ready in background
        self._prefetch_in_progress = set()  # Shard IDs being prefetched
        self._prefetch_ready = {}  # shard_id -> tensor data (ready to use)
        self._prefetch_ahead = 2  # Number of shards to prefetch ahead (was 1)

        # LOSS PLATEAU DETECTION: Track loss to detect when to rotate shards early
        self._loss_history = []  # Recent loss values
        self._loss_history_max = 50  # Number of loss values to track
        self._loss_plateau_threshold = 0.02  # If loss variance < this, plateau detected
        self._min_steps_per_shard = 100  # Minimum steps before considering early rotation
        self._steps_on_current_shard = 0  # Steps taken on current shard

        # Initialize Data Swarm for P2P downloading
        self.swarm = None

        self.current_dataset = None
        self.dataset_iterator = 0

        # Fetch manifest and assign initial shards
        self._refresh_manifest()
        self._assign_shards()

        # Try to load learned tokenizer from CDN (for proper vocab)
        self._load_learned_tokenizer()

        # THUNDERING HERD PREVENTION: Add deterministic jitter before the first
        # download. This spreads load across the CDN when many nodes start
        # simultaneously. Jitter: 0-5 seconds derived from the node_id hash.
        jitter_seed = int(hashlib.sha256(self.node_id.encode()).hexdigest()[:8], 16)
        jitter_seconds = (jitter_seed % 5000) / 1000.0  # 0-5 seconds

        def delayed_prefetch():
            time.sleep(jitter_seconds)
            self._start_prefetch_next()

        # Start prefetching first shard with jitter (non-blocking)
        threading.Thread(target=delayed_prefetch, daemon=True).start()

        logger.info(f"GenesisDataLoader initialized: "
                    f"total_shards={self.total_shards}, "
                    f"max_storage={max_storage_mb}MB ({self.max_shards} shards), "
                    f"assigned={self.assigned_shard_ids[:5]}{'...' if len(self.assigned_shard_ids) > 5 else ''}, "
                    f"prefetch_jitter={jitter_seconds:.2f}s")

    def _load_learned_tokenizer(self):
        """
        Download and load the learned tokenizer from Genesis CDN.
        Checks if the network has learned more tokens and updates locally.
        """
        try:
            tokenizer_url = f"{self.GENESIS_CDN_URL}/tokenizer.json"
            tokenizer_cache_path = os.path.join(self.cache_dir, "tokenizer.json")

            # Always try to fetch latest from CDN
            try:
                logger.debug(f"[GENESIS] Checking for tokenizer updates from {tokenizer_url}...")
                resp = requests.get(tokenizer_url, timeout=10)

                if resp.status_code == 200:
                    remote_tokenizer_data = resp.json()
                    remote_vocab_size = remote_tokenizer_data.get("next_merge_id", 0)

                    # Always cache the downloaded tokenizer (for offline use)
                    os.makedirs(self.cache_dir, exist_ok=True)
                    with open(tokenizer_cache_path, 'w') as f:
                        f.write(resp.text)

                    # Update our tokenizer if remote has more tokens
                    if remote_vocab_size > self.tokenizer.next_merge_id:
                        logger.info(f"[GENESIS] Found improved tokenizer! ({self.tokenizer.next_merge_id} -> {remote_vocab_size} tokens)")

                        from neuroshard.core.model.tokenizer import NeuroTokenizer
                        learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id

                        logger.info(f"[GENESIS] Tokenizer updated: {self.tokenizer.next_merge_id} tokens")
                    else:
                        logger.info(f"[GENESIS] Tokenizer cached: {remote_vocab_size} tokens (current: {self.tokenizer.next_merge_id})")
                    return
            except Exception as e:
                logger.debug(f"[GENESIS] Failed to check for tokenizer updates: {e}")

            # Fallback to cached version if download failed
            if os.path.exists(tokenizer_cache_path) and self.tokenizer.next_merge_id <= 266:
                logger.info(f"[GENESIS] Loading cached tokenizer from {tokenizer_cache_path}")
                try:
                    from neuroshard.core.model.tokenizer import NeuroTokenizer
                    learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                    if learned_tokenizer.next_merge_id > self.tokenizer.next_merge_id:
                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id
                        logger.info(f"[GENESIS] Loaded cached tokenizer: {self.tokenizer.next_merge_id} tokens")
                except Exception as e:
                    logger.warning(f"[GENESIS] Failed to load cached tokenizer: {e}")

        except Exception as e:
            logger.warning(f"[GENESIS] Error managing tokenizer: {e}")

    def _refresh_manifest_sync(self):
        """Synchronous manifest fetch (runs in background thread)."""
        try:
            logger.info(f"[GENESIS] Fetching manifest from {self.manifest_url}...")
            resp = requests.get(self.manifest_url, timeout=15)
            if resp.status_code == 200:
                manifest_data = resp.json()
                total_shards = manifest_data.get("total_shards", 0)

                # Update state atomically
                with self._shard_lock:
                    self.manifest = manifest_data
                    self.total_shards = total_shards
                    self.manifest_last_fetch = time.time()

                logger.info(f"[GENESIS] Manifest loaded: {self.total_shards} shards available")

                # Also check if tokenizer has improved (in background)
                self._load_learned_tokenizer()
            else:
                logger.error(f"[GENESIS] Failed to fetch manifest: HTTP {resp.status_code}")
                logger.error(f"[GENESIS] Response: {resp.text[:200]}")
        except Exception as e:
            logger.error(f"[GENESIS] Failed to fetch manifest from {self.manifest_url}: {type(e).__name__}: {e}")
            import traceback
            logger.error(f"[GENESIS] Traceback: {traceback.format_exc()}")

    def _refresh_manifest(self):
        """Fetch latest manifest from S3 (non-blocking after first load)."""
        now = time.time()

        # First time initialization - must be synchronous
        if self.manifest is None:
            self._refresh_manifest_sync()
            if self.total_shards == 0:
                raise RuntimeError(f"Cannot fetch manifest from {self.manifest_url}. Check S3 bucket.")
            return

        # Subsequent refreshes - use cached if recent
        if (now - self.manifest_last_fetch) < self.MANIFEST_REFRESH_INTERVAL:
            return  # Use cached manifest

        # Refresh in background (non-blocking)
        self._download_executor.submit(self._refresh_manifest_sync)

    def _assign_shards(self):
        """
        Assign shards to this node based on:
        1. Node's deterministic hash (ensures different nodes get different shards)
        2. User's storage limit (max_shards)
        3. Rotation offset (allows cycling through entire dataset over time)
        """
        if self.total_shards == 0:
            self.assigned_shard_ids = [0]
            return

        # Base offset from node ID (deterministic)
        node_hash = int(hashlib.sha256(self.node_id.encode()).hexdigest(), 16)
        base_offset = node_hash % self.total_shards

        # Rotation offset (changes over time to cover more data)
        rotation_offset = (self.shard_rotation_count * self.max_shards) % self.total_shards

        # Assign shards starting from (base + rotation) offset
        self.assigned_shard_ids = []
        for i in range(self.max_shards):
            shard_id = (base_offset + rotation_offset + i) % self.total_shards
            self.assigned_shard_ids.append(shard_id)

        logger.info(f"Assigned {len(self.assigned_shard_ids)} shards: "
                    f"{self.assigned_shard_ids[:5]}{'...' if len(self.assigned_shard_ids) > 5 else ''}")

    def rotate_shards(self):
        """
        Rotate to next set of shards.
        Call this periodically to train on different parts of the dataset.
        """
        # Clear old loaded shards to free memory
        old_shards = list(self.loaded_shards.keys())
        self.loaded_shards.clear()
        self.current_dataset = None
        self.dataset_iterator = 0

        # Increment rotation counter
        self.shard_rotation_count += 1

        # Refresh manifest (in case new shards were added)
        self._refresh_manifest()

        # Reassign shards with new rotation offset
        self._assign_shards()

        # Clean up old shard files from disk
        self._cleanup_old_shards(old_shards)

        logger.info(f"Rotated to new shards (rotation #{self.shard_rotation_count})")

    def _cleanup_old_shards(self, old_shard_ids: list):
        """Remove old shard files from disk to stay within storage limit."""
        for shard_id in old_shard_ids:
            if shard_id not in self.assigned_shard_ids:
                shard_path = os.path.join(self.cache_dir, f"genesis_shard_{shard_id}.pt")
                try:
                    if os.path.exists(shard_path):
                        os.remove(shard_path)
                        logger.debug(f"Cleaned up old shard: {shard_path}")
                except Exception as e:
                    logger.warning(f"Failed to cleanup shard {shard_id}: {e}")

    def set_swarm(self, swarm):
        """Set the DataSwarm instance."""
        self.swarm = swarm

    def record_loss(self, loss: float):
        """
        Record a training loss for plateau detection.

        Call this from the training loop to enable adaptive shard rotation.
        When loss plateaus, the loader will rotate to fresh data.
        """
        self._loss_history.append(loss)
        if len(self._loss_history) > self._loss_history_max:
            self._loss_history.pop(0)
        self._steps_on_current_shard += 1

    def _should_rotate_early(self) -> bool:
        """
        Check if we should rotate to a new shard early due to loss plateau.

        Conditions for early rotation:
        1. Have enough loss samples (at least 20)
        2. Minimum steps on current shard (100)
        3. Loss has plateaued (low variance)
        4. Loss is low enough that we're not still actively learning
        """
        if len(self._loss_history) < 20:
            return False

        if self._steps_on_current_shard < self._min_steps_per_shard:
            return False

        # Calculate loss statistics
        recent_losses = self._loss_history[-20:]
        avg_loss = sum(recent_losses) / len(recent_losses)
        variance = sum((l - avg_loss) ** 2 for l in recent_losses) / len(recent_losses)

        # Also check if loss is decreasing (don't rotate if still improving)
        if len(self._loss_history) >= 40:
            older_avg = sum(self._loss_history[-40:-20]) / 20
            improvement = older_avg - avg_loss

            # Still improving significantly - don't rotate
            if improvement > 0.005:
                return False

        # Plateau detected: low variance AND low absolute loss
        if variance < self._loss_plateau_threshold and avg_loss < 0.05:
            logger.info(f"[GENESIS] Loss plateau detected: avg={avg_loss:.4f}, variance={variance:.6f}")
            logger.info("[GENESIS] Rotating to fresh data for continued learning")
            return True

        return False

    def force_shard_rotation(self, reason: str = "manual"):
        """
        Force rotation to a new shard.

        Call this when you want to move to fresh data (e.g., loss plateau).
        """
        logger.info(f"[GENESIS] Forcing shard rotation: {reason}")
        self._loss_history.clear()
        self._steps_on_current_shard = 0

        # Move to next shard
        self.current_shard_idx += 1

        if self.current_shard_idx >= len(self.assigned_shard_ids):
            # We've gone through all assigned shards - rotate to new set
            logger.info(f"[GENESIS] Exhausted all {len(self.assigned_shard_ids)} assigned shards. Getting new set...")
            self.rotate_shards()

        # Reset dataset iterator to start fresh
        self.current_dataset = None
        self.dataset_iterator = 0

        # Start prefetching the new shard
        self._start_prefetch_next()

    def _start_prefetch_next(self):
        """Start prefetching the next shard(s) in background."""
        if not self.assigned_shard_ids:
            return

        # Prefetch current and multiple next shards for faster data access
        shards_to_prefetch = []
        for offset in range(self._prefetch_ahead + 1):  # Current + prefetch_ahead (default: 0, 1, 2)
            idx = (self.current_shard_idx + offset) % len(self.assigned_shard_ids)
            shard_id = self.assigned_shard_ids[idx]

            with self._shard_lock:
                # Skip if already loaded, prefetching, or ready
                if (shard_id in self.loaded_shards or
                        shard_id in self._prefetch_in_progress or
                        shard_id in self._prefetch_ready or
                        shard_id in self.loading_shards):
                    continue

                # Limit total prefetch in progress to avoid overwhelming the system
                if len(self._prefetch_in_progress) >= 3:
                    break

                shards_to_prefetch.append(shard_id)
                self._prefetch_in_progress.add(shard_id)

        # Start downloads in background
        for shard_id in shards_to_prefetch:
            target_url = self.get_shard_url(shard_id)
            logger.debug(f"Prefetching shard {shard_id} in background...")
            self._download_executor.submit(self._prefetch_shard_sync, shard_id, target_url)

    def _prefetch_shard_sync(self, shard_id: int, target_url: str):
        """Synchronous shard prefetch (runs in background thread)."""
        try:
            logger.info(f"[GENESIS] Downloading shard {shard_id}...")
            # Download the shard
            shard_path = None

            if self.swarm:
                try:
                    shard_path = self.swarm.download_shard(shard_id, manifest_url=target_url)
                    logger.info(f"[GENESIS] Swarm download succeeded for shard {shard_id}")
                except Exception as e:
                    logger.warning(f"[GENESIS] Swarm prefetch failed: {e}")

            if not shard_path:
                logger.info(f"[GENESIS] Using HTTP fallback for shard {shard_id}")
                shard_path = self._http_fallback_download(shard_id, target_url)
                logger.info(f"[GENESIS] HTTP download completed for shard {shard_id}")

            # Load tensor into prefetch buffer
            tensor_data = torch.load(shard_path, weights_only=True)

            with self._shard_lock:
                # DYNAMIC MEMORY LIMIT: Based on user's max_storage_mb setting.
                # Each shard is ~10MB compressed on disk, ~100-200MB uncompressed in RAM.
                # Calculate max shards we can keep in memory.
                shard_size_mb = 150  # Conservative estimate per shard in RAM
                max_cached_shards = max(3, int(self.max_storage_mb / shard_size_mb))

                total_loaded = len(self.loaded_shards) + len(self._prefetch_ready)
                if total_loaded >= max_cached_shards:
                    # Clear oldest loaded shard (not the current one)
                    current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)] if self.assigned_shard_ids else None
                    for old_shard_id in list(self.loaded_shards.keys()):
                        if old_shard_id != current_shard:
                            del self.loaded_shards[old_shard_id]
                            logger.debug(f"Evicted shard {old_shard_id} from cache (limit: {max_cached_shards} shards)")
                            break

                self._prefetch_ready[shard_id] = tensor_data
                self._prefetch_in_progress.discard(shard_id)

            logger.info(f"[GENESIS] Shard {shard_id} ready: {len(tensor_data):,} tokens")

        except Exception as e:
            logger.error(f"[GENESIS] Download FAILED for shard {shard_id}: {type(e).__name__}: {e}")
            import traceback
            logger.error(f"[GENESIS] Traceback: {traceback.format_exc()}")
            with self._shard_lock:
                self._prefetch_in_progress.discard(shard_id)

    def is_data_ready(self) -> bool:
        """Check if data is ready for training (non-blocking check)."""
        # Try to acquire lock with timeout to prevent blocking training loop
        acquired = self._shard_lock.acquire(timeout=0.5)
        if not acquired:
            # Lock held by download thread - assume data might be ready soon
            logger.debug("[GENESIS] Lock contention in is_data_ready - skipping check")
            return False

        try:
            # Data ready if we have current dataset OR prefetched shard is ready
            if self.current_dataset is not None and len(self.current_dataset) > 0:
                return True

            # Check if ANY assigned shard is ready (not just current).
            # This handles the case where prefetch completes before is_data_ready is called.
            if self._prefetch_ready:
                # A prefetched shard is ready - we can use it
                return True

            # Also check loaded_shards
            if self.loaded_shards:
                return True

            # Check if current shard is specifically ready
            if self.assigned_shard_ids:
                shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
                if shard_id in self._prefetch_ready:
                    return True
                if shard_id in self.loaded_shards:
                    return True

            return False
        finally:
            self._shard_lock.release()

    def get_shard_url(self, shard_id: int) -> str:
        """Get download URL for a specific shard (always use CDN)."""
        # Always use the CDN URL regardless of what the manifest says.
        # This ensures we go through CloudFront for caching/security.
        return f"{self.GENESIS_CDN_URL}/shard_{shard_id}.pt"

    def _load_shard_sync(self, shard_id: int, target_url: str):
        """Synchronous shard loading (runs in background thread)."""
        # Download the shard (swarm or HTTP)
        shard_path = None

        if self.swarm:
            try:
                shard_path = self.swarm.download_shard(shard_id, manifest_url=target_url)
            except Exception as e:
                logger.error(f"Swarm download failed: {e}")

        if not shard_path:
            shard_path = self._http_fallback_download(shard_id, target_url)

        # Load tensor
        try:
            tensor_data = torch.load(shard_path, weights_only=True)
            with self._shard_lock:
                self.loaded_shards[shard_id] = tensor_data
                self.current_dataset = tensor_data
                self.dataset_iterator = 0
                self.loading_shards.discard(shard_id)
            logger.info(f"Loaded Shard {shard_id}: {len(tensor_data)} tokens")
        except Exception as e:
            logger.error(f"Failed to load shard {shard_path}: {e}")
            with self._shard_lock:
                self.loading_shards.discard(shard_id)
            # Create dummy data if all else fails (use valid byte tokens 10-265)
            self.current_dataset = torch.randint(10, 266, (10000,), dtype=torch.long)

    def ensure_shard_loaded(self, shard_id: int = None):
        """
        Download and load a shard if not present.
        Opportunistically switches to ANY ready shard if the target isn't ready.
        """
        target_shard_id = shard_id

        if target_shard_id is None:
            # Default: try current shard in rotation
            if not self.assigned_shard_ids:
                return
            target_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]

        with self._shard_lock:
            # 1. Check if target is ready (fastest)
            if target_shard_id in self.loaded_shards:
                self.current_dataset = self.loaded_shards[target_shard_id]
                return

            # 2. Check if target is in prefetch buffer
            if target_shard_id in self._prefetch_ready:
                self.current_dataset = self._prefetch_ready.pop(target_shard_id)
                self.loaded_shards[target_shard_id] = self.current_dataset
                self.dataset_iterator = 0
                logger.info(f"Using prefetched shard {target_shard_id}: {len(self.current_dataset)} tokens")
                self._start_prefetch_next_unlocked()
                return

            # 3. OPPORTUNISTIC: If target isn't ready, check if ANY assigned shard is ready in prefetch.
            # This prevents blocking on shard A when shard B is already downloaded.
            if shard_id is None:  # Only if caller didn't request specific shard
                for ready_id in list(self._prefetch_ready.keys()):
                    if ready_id in self.assigned_shard_ids:
                        # Switch to this ready shard!
                        logger.info(f"Opportunistically switching to ready shard {ready_id} (was waiting for {target_shard_id})")

                        # Update index to match
                        try:
                            new_idx = self.assigned_shard_ids.index(ready_id)
                            self.current_shard_idx = new_idx
                        except ValueError:
                            pass

                        self.current_dataset = self._prefetch_ready.pop(ready_id)
                        self.loaded_shards[ready_id] = self.current_dataset
                        self.dataset_iterator = 0
                        self._start_prefetch_next_unlocked()
                        return

            # 4. If still nothing, trigger download for target
            if target_shard_id in self.loading_shards or target_shard_id in self._prefetch_in_progress:
                logger.debug(f"Shard {target_shard_id} is already being downloaded, waiting...")
                return  # Don't block

            # Mark as loading and start download in background
            self.loading_shards.add(target_shard_id)

        target_url = self.get_shard_url(target_shard_id)
        logger.info(f"Loading Shard {target_shard_id} from {target_url}")

        # Submit to thread pool (non-blocking)
        self._download_executor.submit(self._load_shard_sync, target_shard_id, target_url)

    def _start_prefetch_next_unlocked(self):
        """Start prefetching next shard (call only when holding _shard_lock)."""
        # Schedule prefetch in background (don't hold lock during download)
        self._download_executor.submit(self._start_prefetch_next)

    def _http_fallback_download(self, shard_id: int, target_url: str = None) -> str:
        """Download shard from CloudFront CDN."""
        os.makedirs(self.cache_dir, exist_ok=True)
        shard_path = os.path.join(self.cache_dir, f"genesis_shard_{shard_id}.pt")

        if os.path.exists(shard_path):
            return shard_path

        # Use target URL from manifest, or construct CDN URL
        url = target_url or f"{self.GENESIS_CDN_URL}/shard_{shard_id}.pt"

        try:
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(shard_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info(f"Downloaded shard {shard_id}: {os.path.getsize(shard_path)/1e6:.1f}MB")
            return shard_path
        except Exception as e:
            logger.error(f"Failed to download shard {shard_id} from {url}: {e}")
            raise RuntimeError(f"Failed to download shard {shard_id}: {e}")

    def get_batch(self, batch_size: int = 4, seq_len: int = 512) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch from the current shard.

        NON-BLOCKING VERSION: Returns quickly if data not ready.
        Uses prefetch buffer for instant shard switches.

        Automatically rotates to next shard when current one is exhausted.
        Returns (input_ids, labels).

        Raises RuntimeError if data not ready (caller should retry later).
        """
        # Try to load from prefetch buffer first
        self.ensure_shard_loaded()

        # NON-BLOCKING: Check if data is actually ready.
        # Don't wait/block - let the caller handle the retry.
        if self.current_dataset is None:
            # Check if anything is in progress
            with self._shard_lock:
                loading_any = bool(self.loading_shards or self._prefetch_in_progress)
                prefetch_ready = bool(self._prefetch_ready)

            if prefetch_ready:
                # There's a prefetched shard - try to use it
                self.ensure_shard_loaded()
            elif not loading_any:
                # Nothing loading - kick off a new load
                self._start_prefetch_next()

            # Return early - data not ready yet
            raise RuntimeError("Data not ready - shard still loading")

        data_len = len(self.current_dataset)
        req_len = (batch_size * seq_len) + 1

        # Check for early rotation due to loss plateau
        if self._should_rotate_early():
            current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            logger.info(f"[GENESIS] Early rotation from shard {current_shard} due to loss plateau")
            self.force_shard_rotation("loss_plateau")
            # Ensure new shard is loaded
            self.ensure_shard_loaded()
            if self.current_dataset is None:
                raise RuntimeError("Data not ready - loading fresh shard after plateau")
            data_len = len(self.current_dataset)

        # Check if we've exhausted current shard
        if self.dataset_iterator + req_len > data_len:
            # Log completion of current shard
            completed_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            steps_done = data_len // req_len
            logger.info(f"✓ Completed shard {completed_shard} ({steps_done} steps, {data_len:,} tokens)")

            # Reset loss tracking for new shard
            self._loss_history.clear()
            self._steps_on_current_shard = 0

            # Move to next shard in our assigned list
            self.current_shard_idx += 1

            if self.current_shard_idx >= len(self.assigned_shard_ids):
                # We've gone through all assigned shards - rotate to new set
                logger.info(f"Exhausted all {len(self.assigned_shard_ids)} assigned shards. Rotating to new set...")
                self.rotate_shards()

            # Try to use prefetched shard (FAST PATH)
            next_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]

            switched = False
            with self._shard_lock:
                if next_shard_id in self._prefetch_ready:
                    # Instant switch to prefetched shard
                    self.current_dataset = self._prefetch_ready.pop(next_shard_id)
                    self.loaded_shards[next_shard_id] = self.current_dataset
                    logger.info(f"Switched to prefetched shard {next_shard_id}: {len(self.current_dataset)} tokens")
                    switched = True
                elif next_shard_id in self.loaded_shards:
                    self.current_dataset = self.loaded_shards[next_shard_id]
                    switched = True

            if not switched:
                # Need to wait for next shard - trigger load.
                # (Done after releasing the lock: ensure_shard_loaded acquires
                # the non-reentrant _shard_lock itself.)
                self.ensure_shard_loaded(next_shard_id)
                raise RuntimeError("Data not ready - loading next shard")

            # Start prefetching the shard after next
            self._start_prefetch_next()

            self.dataset_iterator = 0
            data_len = len(self.current_dataset)

        start_idx = self.dataset_iterator
        end_idx = start_idx + req_len

        chunk = self.current_dataset[start_idx:end_idx]
        self.dataset_iterator += req_len

        # Log shard progress periodically (every 100 steps within shard)
        steps_in_shard = self.dataset_iterator // req_len
        total_steps_in_shard = data_len // req_len
        if steps_in_shard % 100 == 0:
            progress_pct = (self.dataset_iterator / data_len) * 100
            current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            logger.info(f"Shard {current_shard} progress: {progress_pct:.1f}% "
                        f"({steps_in_shard}/{total_steps_in_shard} steps)")

        # Prepare batch
        exact_len = batch_size * seq_len

        inputs = chunk[:exact_len].view(batch_size, seq_len)
        labels = chunk[1:exact_len + 1].view(batch_size, seq_len)

        return inputs, labels

    def get_stats(self) -> dict:
        """Get loader statistics."""
        # Calculate progress within current shard
        shard_progress = 0.0
        steps_in_shard = 0
        total_steps_in_shard = 0
        current_shard_id = None

        if self.current_dataset is not None and len(self.current_dataset) > 0:
            data_len = len(self.current_dataset)
            req_len = 1025  # Approximate: batch_size * seq_len + 1
            shard_progress = (self.dataset_iterator / data_len) * 100
            steps_in_shard = self.dataset_iterator // req_len
            total_steps_in_shard = data_len // req_len
            if self.assigned_shard_ids:
                current_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]

        # Compute loss plateau stats
        loss_avg = 0.0
        loss_variance = 0.0
        if self._loss_history:
            loss_avg = sum(self._loss_history) / len(self._loss_history)
            if len(self._loss_history) >= 2:
                loss_variance = sum((l - loss_avg) ** 2 for l in self._loss_history) / len(self._loss_history)

        return {
            "total_shards_available": self.total_shards,
            "max_shards_configured": self.max_shards,
            "max_storage_mb": self.max_storage_mb,
            "assigned_shards": len(self.assigned_shard_ids),
            "loaded_shards": len(self.loaded_shards),
            "prefetch_in_progress": len(self._prefetch_in_progress),
            "prefetch_ready": len(self._prefetch_ready),
            "current_shard_idx": self.current_shard_idx,
            "current_shard_id": current_shard_id,
            "shard_progress_pct": round(shard_progress, 1),
            "steps_in_shard": steps_in_shard,
            "total_steps_in_shard": total_steps_in_shard,
            "rotation_count": self.shard_rotation_count,
            "storage_used_mb": len(self.loaded_shards) * self.SHARD_SIZE_MB,
            # Loss plateau detection stats
            "steps_on_current_shard": self._steps_on_current_shard,
            "loss_history_size": len(self._loss_history),
            "loss_avg": round(loss_avg, 6),
            "loss_variance": round(loss_variance, 8),
            "plateau_threshold": self._loss_plateau_threshold,
        }

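Because get_batch is non-blocking and raises RuntimeError while a shard is still downloading, the caller is expected to retry. A minimal driver sketch (my_tokenizer and train_step are placeholders for the node's existing tokenizer and training step):

    import time
    from neuroshard.core.training.distributed import GenesisDataLoader

    loader = GenesisDataLoader(
        node_id="my-node-id",        # placeholder
        tokenizer=my_tokenizer,      # placeholder: a NeuroTokenizer instance
        max_storage_mb=100.0,
    )

    while True:
        try:
            inputs, labels = loader.get_batch(batch_size=4, seq_len=512)
        except RuntimeError:
            time.sleep(2)            # shard still loading; retry
            continue
        loss = train_step(inputs, labels)  # placeholder
        loader.record_loss(loss)           # feeds plateau-based early rotation
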
class DataValidator:
    """
    Validates training data quality before it enters the buffer.

    Prevents garbage/spam from polluting the local training set.
    """
    def __init__(self):
        pass

    def validate_text(self, text: str) -> Tuple[bool, str]:
        """
        Validate text quality.
        Returns (is_valid, reason).
        """
        if not text or not text.strip():
            return False, "Empty text"

        if len(text) < 20:
            return False, "Text too short (<20 chars)"

        # Entropy check (compression ratio):
        # highly repetitive text compresses too well (ratio > 5.0),
        # random text compresses poorly (ratio ~ 1.0).
        compressed = zlib.compress(text.encode())
        ratio = len(text) / len(compressed)

        if ratio > 6.0:
            return False, f"High compression ratio ({ratio:.1f}) - likely repetitive spam"

        if ratio < 1.1 and len(text) > 200:
            return False, f"Low compression ratio ({ratio:.1f}) - likely random gibberish"

        # Basic character distribution check:
        # reject text that is mostly special characters.
        alnum_count = sum(c.isalnum() for c in text)
        if alnum_count / len(text) < 0.5:
            return False, "Too many special characters"

        return True, "OK"

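The compression-ratio heuristic in action (illustrative; exact ratios depend on zlib):

    from neuroshard.core.training.distributed import DataValidator

    validator = DataValidator()
    print(validator.validate_text("spam " * 200))
    # (False, 'High compression ratio (...) - likely repetitive spam')
    print(validator.validate_text("!@# $%^ &*( )_+ =-[ ]{} ;: <>?"))
    # (False, 'Too many special characters')
    print(validator.validate_text(
        "Distributed training shards gradients across peers and "
        "aggregates them with a weighted mean."
    ))
    # (True, 'OK')
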
class FederatedDataManager:
    """
    Manages federated dataset for distributed training.

    Nodes can contribute:
    1. Text data (tokenized)
    2. Curated datasets
    3. Synthetic data from other models

    Privacy features:
    - Differential privacy (noise injection)
    - Data hashing (no raw text stored)
    - Local processing only
    """

    def __init__(self, tokenizer, max_seq_len: int = 2048):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # Validator
        self.validator = DataValidator()

        # Local data buffer
        self.data_buffer: List[torch.Tensor] = []
        self.max_buffer_size = 10000

        # Stats
        self.total_samples_contributed = 0
        self.total_tokens_contributed = 0
        self.rejected_samples = 0

    def add_text(self, text: str, apply_dp: bool = True, epsilon: float = 1.0):
        """
        Add text to the local training buffer.

        Args:
            text: Raw text to add
            apply_dp: Apply differential privacy
            epsilon: DP epsilon (lower = more private)
        """
        # Validate first
        is_valid, reason = self.validator.validate_text(text)
        if not is_valid:
            logger.warning(f"Rejected training data: {reason}")
            self.rejected_samples += 1
            return

        # Tokenize
        tokens = self.tokenizer.encode(text)

        if len(tokens) == 0:
            return

        # Chunk into sequences with overlap.
        # Use a smaller chunk size for flexibility.
        chunk_size = min(self.max_seq_len, 512)  # Use 512 for training efficiency
        stride = chunk_size // 2  # 50% overlap

        chunks_added = 0
        for i in range(0, max(1, len(tokens) - chunk_size + 1), stride):
            chunk = tokens[i:i + chunk_size]

            # Pad if needed
            if len(chunk) < chunk_size:
                chunk = chunk + [self.tokenizer.pad_token_id] * (chunk_size - len(chunk))

            tensor = torch.tensor(chunk, dtype=torch.long)

            # Apply differential privacy (token-level noise)
            if apply_dp:
                tensor = self._apply_dp(tensor, epsilon)

            self.data_buffer.append(tensor)
            self.total_samples_contributed += 1
            self.total_tokens_contributed += len(chunk)
            chunks_added += 1

        # Also handle short texts (< chunk_size)
        if len(tokens) < chunk_size and chunks_added == 0:
            chunk = tokens + [self.tokenizer.pad_token_id] * (chunk_size - len(tokens))
            tensor = torch.tensor(chunk, dtype=torch.long)
            if apply_dp:
                tensor = self._apply_dp(tensor, epsilon)
            self.data_buffer.append(tensor)
            self.total_samples_contributed += 1
            self.total_tokens_contributed += len(tokens)

        # Trim buffer if too large
        if len(self.data_buffer) > self.max_buffer_size:
            self.data_buffer = self.data_buffer[-self.max_buffer_size:]

    def _apply_dp(self, tokens: torch.Tensor, epsilon: float) -> torch.Tensor:
        """Apply differential privacy to tokens."""
        # Simple DP: randomly replace some tokens.
        # More sophisticated methods would use the exponential mechanism.
        noise_mask = torch.rand(tokens.shape) < (1.0 / epsilon)
        # Use current_vocab_size (not max vocab_size) to only sample valid tokens
        valid_vocab_size = getattr(self.tokenizer, 'current_vocab_size', 266)
        random_tokens = torch.randint(0, valid_vocab_size, tokens.shape)
        return torch.where(noise_mask, random_tokens, tokens)

    def get_batch(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch for training.

        Returns:
            (input_ids, labels) - labels are shifted input_ids
        """
        if len(self.data_buffer) < batch_size:
            raise ValueError(f"Not enough data: have {len(self.data_buffer)}, need {batch_size}")

        # Random sample
        import random
        indices = random.sample(range(len(self.data_buffer)), batch_size)
        batch = torch.stack([self.data_buffer[i] for i in indices])

        # For causal LM, labels = inputs shifted by 1
        input_ids = batch[:, :-1]
        labels = batch[:, 1:]

        return input_ids, labels

    def get_stats(self) -> Dict[str, Any]:
        """Get data contribution stats."""
        return {
            "buffer_size": len(self.data_buffer),
            "total_samples": self.total_samples_contributed,
            "total_tokens": self.total_tokens_contributed,
            "rejected_samples": self.rejected_samples,
        }
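
A short end-to-end sketch of the contribution buffer (my_tokenizer and corpus_text are placeholders; the tokenizer must expose encode() and pad_token_id, as the class requires):

    from neuroshard.core.training.distributed import FederatedDataManager

    manager = FederatedDataManager(tokenizer=my_tokenizer, max_seq_len=2048)
    manager.add_text(corpus_text, apply_dp=True, epsilon=1.0)  # any long validated string

    if len(manager.data_buffer) >= 4:
        input_ids, labels = manager.get_batch(batch_size=4)
        print(input_ids.shape, manager.get_stats())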