nexaroa-0.0.111-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
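
The listing above can be reproduced locally: a wheel is a plain zip archive, so its contents can be enumerated with the standard library. A minimal sketch, assuming the wheel has already been downloaded (e.g. with `pip download nexaroa==0.0.111 --no-deps`); the filename is taken from the dist-info entries above:

```python
# Sketch: list the files inside the published wheel (a wheel is a zip archive).
# Assumes the wheel sits in the current directory, e.g. fetched via:
#   pip download nexaroa==0.0.111 --no-deps
import zipfile

WHEEL = "nexaroa-0.0.111-py3-none-any.whl"  # filename inferred from the listing above

with zipfile.ZipFile(WHEEL) as wheel:
    for info in wheel.infolist():
        # Print uncompressed size and path for every member of the archive
        print(f"{info.file_size:>10}  {info.filename}")
```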
neuroshard/core/swarm/factory.py
@@ -0,0 +1,1288 @@
+ """
+ Swarm Node Factory - Creates and configures SwarmEnabledNodes
+
+ This is THE architecture for NeuroShard. Every node runs the full swarm stack:
+ - SwarmRouter for fault-tolerant multipath routing
+ - ActivationBuffer for async compute decoupling
+ - SwarmHeartbeatService for capacity advertisement
+ - DiLoCoTrainer for lazy gradient sync
+ - SpeculativeCheckpointer for fast crash recovery
+ - RobustAggregator for Byzantine-tolerant gradient aggregation
+
+ There are NO toggles, NO fallbacks, NO backward compatibility modes.
+ Swarm IS the architecture.
+
+ Usage:
+     from neuroshard.core.swarm import create_swarm_node, SwarmNodeConfig
+
+     config = SwarmNodeConfig(
+         diloco_inner_steps=500,
+         checkpoint_interval=120,
+     )
+
+     swarm_node = create_swarm_node(
+         node_token=token,
+         port=port,
+         tracker_url=tracker,
+         config=config,
+     )
+ """
+
+ import asyncio
+ import logging
+ import threading
+ import time
+ from dataclasses import dataclass
+ from pathlib import Path
+ from typing import Dict, List, Optional, Any, Tuple
+
+ import torch
+
+ from neuroshard.core.model.dynamic import (
+     DynamicNeuroNode,
+     create_dynamic_node,
+     DynamicNeuroLLM,
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ @dataclass
+ class SwarmNodeConfig:
+     """
+     Configuration for SwarmEnabledNode.
+
+     All values have sensible defaults but can be customized.
+
+     NETWORK SIZE CONSIDERATIONS:
+     ===========================
+     Small network (1-10 nodes):
+         - diloco_inner_steps: 100-200 (sync more often for convergence)
+         - aggregation_method: "mean" (few Byzantine concerns)
+         - cosine_threshold: 0.3 (more lenient, peers vary)
+
+     Medium network (10-100 nodes):
+         - diloco_inner_steps: 300-500 (balance communication/compute)
+         - aggregation_method: "trimmed_mean" (some Byzantine protection)
+         - cosine_threshold: 0.5 (stricter validation)
+
+     Large network (100+ nodes):
+         - diloco_inner_steps: 500-1000 (reduce communication overhead)
+         - aggregation_method: "trimmed_mean" or "krum" (Byzantine-tolerant)
+         - cosine_threshold: 0.5 (strict validation)
+     """
+
+     # Buffer Sizes
+     inbound_buffer_size: int = 100
+     outbound_buffer_size: int = 50
+     soft_overflow_threshold: float = 0.9
+     hard_overflow_threshold: float = 0.99
+
+     # Routing
+     ack_timeout_ms: int = 200
+     k_candidates: int = 3
+
+     # Heartbeat
+     heartbeat_interval: float = 5.0
+     heartbeat_port: int = 9999
+
+     # DiLoCo - Use get_diloco_inner_steps() for dynamic adjustment
+     diloco_inner_steps: int = 500  # Base value (adjusted by network size)
+     diloco_inner_steps_min: int = 50  # Minimum for very small networks
+     diloco_inner_steps_max: int = 1000  # Maximum for very large networks
+     diloco_outer_lr: float = 0.7
+     diloco_outer_momentum: float = 0.9
+     diloco_auto_scale: bool = True  # Auto-adjust inner_steps based on network size
+
+     # Checkpointing
+     checkpoint_interval: int = 120  # 2 minutes
+     max_checkpoints: int = 5
+     checkpoint_dir: Optional[str] = None  # Default: ~/.neuroshard/checkpoints
+
+     # Aggregation
+     aggregation_method: str = "trimmed_mean"  # "mean", "median", "trimmed_mean", "krum"
+     krum_f: int = 0  # Byzantine workers for Krum
+     trimmed_mean_beta: float = 0.1  # Trim fraction
+
+     # Gradient Validation
+     cosine_threshold: float = 0.5
+     magnitude_ratio_threshold: float = 10.0
+
+     # Compute Engine
+     num_micro_batches: int = 4
+
+     def get_checkpoint_dir(self) -> Path:
+         """Get checkpoint directory, creating if needed."""
+         if self.checkpoint_dir:
+             path = Path(self.checkpoint_dir)
+         else:
+             path = Path.home() / ".neuroshard" / "checkpoints"
+         path.mkdir(parents=True, exist_ok=True)
+         return path
+
+     def get_diloco_inner_steps(self, num_peers: int = 1) -> int:
+         """
+         Get optimal DiLoCo inner steps based on network size.
+
+         Rationale:
+             - Small networks: Sync more often for faster convergence
+             - Large networks: Sync less often to reduce communication overhead
+
+         Formula: base * (1 + log10(num_peers) / 3), clamped to [min, max]
+
+         Examples:
+             - 1 peer: 500 * 1.0 = 500 steps
+             - 10 peers: 500 * 1.33 = 665 steps
+             - 100 peers: 500 * 1.67 = 833 steps
+             - 1000 peers: 500 * 2.0 = 1000 steps (capped)
+         """
+         if not self.diloco_auto_scale or num_peers <= 1:
+             return self.diloco_inner_steps
+
+         import math
+         scale_factor = 1.0 + math.log10(max(1, num_peers)) / 3
+         scaled_steps = int(self.diloco_inner_steps * scale_factor)
+
+         return max(self.diloco_inner_steps_min, min(self.diloco_inner_steps_max, scaled_steps))
+
+
+ class SwarmComponents:
+     """
+     Container for all swarm components.
+
+     Provides unified lifecycle management for:
+     - SwarmRouter (multipath routing)
+     - ActivationBuffer/OutboundBuffer (async compute)
+     - SwarmHeartbeatService (capacity advertisement)
+     - ComputeEngine (GPU worker)
+     - DiLoCoTrainer (lazy gradient sync)
+     - SpeculativeCheckpointer (crash recovery)
+     - RobustAggregator (Byzantine-tolerant aggregation)
+     """
+
+     def __init__(self):
+         # Core components
+         self.swarm_router = None
+         self.inbound_buffer = None
+         self.outbound_buffer = None
+         self.heartbeat_service = None
+         self.compute_engine = None
+
+         # Training components
+         self.diloco_trainer = None
+         self.outer_optimizer = None
+         self.speculative_checkpointer = None
+         self.robust_aggregator = None
+         self.gradient_validator = None
+
+         # State
+         self.running = False
+         self._tasks: List[asyncio.Task] = []
+
+     async def start_async(self):
+         """Start all async components."""
+         self.running = True
+
+         if self.swarm_router:
+             await self.swarm_router.start()
+             logger.info("[SWARM] SwarmRouter started")
+
+         if self.compute_engine:
+             task = asyncio.create_task(self.compute_engine.run())
+             self._tasks.append(task)
+             logger.info("[SWARM] ComputeEngine started")
+
+     def start_sync(self):
+         """Start all synchronous components (threads)."""
+         if self.heartbeat_service:
+             self.heartbeat_service.start()
+             logger.info("[SWARM] HeartbeatService started")
+
+         if self.speculative_checkpointer:
+             self.speculative_checkpointer.start()
+             logger.info("[SWARM] SpeculativeCheckpointer started")
+
+     async def stop_async(self):
+         """Stop all async components."""
+         self.running = False
+
+         for task in self._tasks:
+             task.cancel()
+             try:
+                 await task
+             except asyncio.CancelledError:
+                 pass
+         self._tasks.clear()
+
+         if self.swarm_router:
+             await self.swarm_router.stop()
+
+         if self.compute_engine:
+             self.compute_engine.running = False
+
+     def stop_sync(self):
+         """Stop all synchronous components."""
+         if self.heartbeat_service:
+             self.heartbeat_service.stop()
+
+         if self.speculative_checkpointer:
+             self.speculative_checkpointer.stop()
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get combined stats from all components."""
+         stats = {"running": self.running}
+
+         if self.inbound_buffer:
+             stats["inbound_buffer"] = self.inbound_buffer.get_stats()
+
+         if self.outbound_buffer:
+             stats["outbound_buffer"] = self.outbound_buffer.get_stats()
+
+         if self.swarm_router:
+             stats["router"] = self.swarm_router.get_stats()
+
+         if self.heartbeat_service:
+             stats["heartbeat"] = self.heartbeat_service.get_stats()
+
+         if self.compute_engine:
+             stats["compute"] = self.compute_engine.get_stats()
+
+         if self.diloco_trainer:
+             stats["diloco"] = {
+                 "inner_step_count": self.diloco_trainer.stats.inner_step_count,
+                 "inner_steps_total": self.diloco_trainer.config.inner_steps,
+                 "outer_step_count": self.diloco_trainer.stats.outer_step_count,
+             }
+
+         return stats
+
+
+ class SwarmEnabledDynamicNode:
+     """
+     A DynamicNeuroNode with full swarm capabilities.
+
+     This is THE node type for NeuroShard. Every node runs:
+     - Fault-tolerant multipath routing (SwarmRouter)
+     - Async activation buffering (ActivationBuffer, OutboundBuffer)
+     - Capacity-aware peer selection (SwarmHeartbeatService)
+     - Decoupled GPU compute (ComputeEngine)
+     - DiLoCo lazy gradient sync (DiLoCoTrainer)
+     - Speculative checkpointing (SpeculativeCheckpointer)
+     - Byzantine-tolerant aggregation (RobustAggregator)
+
+     There are NO toggles. This IS the architecture.
+     """
+
+     def __init__(
+         self,
+         base_node: DynamicNeuroNode,
+         config: SwarmNodeConfig,
+         p2p_manager: Optional[Any] = None,
+     ):
+         """
+         Initialize SwarmEnabledDynamicNode.
+
+         Args:
+             base_node: The DynamicNeuroNode to enhance (REQUIRED)
+             config: SwarmNodeConfig with settings (REQUIRED)
+             p2p_manager: Optional P2P manager (uses base_node.p2p_manager if not provided)
+         """
+         self.base_node = base_node
+         self.config = config
+         self.p2p_manager = p2p_manager or base_node.p2p_manager
+
+         # Expose base node properties directly
+         self.node_id = base_node.node_id
+         self.node_token = base_node.node_token
+         self.model = base_node.model
+         self.my_layer_ids = base_node.my_layer_ids
+         self.layer_pool = base_node.layer_pool
+         self.enable_training = base_node.enable_training
+         self.device = base_node.device
+         self.available_memory_mb = base_node.available_memory_mb
+
+         # Training state - sync from base node (may have been loaded from checkpoint)
+         self._total_training_rounds = base_node.total_training_rounds
+         self._current_loss = base_node.current_loss
+
+         # Initialize swarm components (router, heartbeat, compute, etc.)
+         # NOTE: Named swarm_components to avoid conflict with base_node.swarm (DataSwarm)
+         self.swarm_components = SwarmComponents()
+         self._init_swarm_components()
+
+         logger.info(f"[SWARM] SwarmEnabledNode initialized for {self.node_id[:16]}...")
+         logger.info(f"[SWARM] - DiLoCo: inner_steps={config.diloco_inner_steps}")
+         logger.info(f"[SWARM] - Checkpointing: interval={config.checkpoint_interval}s")
+         logger.info(f"[SWARM] - Heartbeat: interval={config.heartbeat_interval}s")
+
+         # Make swarm_components accessible from base_node BEFORE restoring pending state
+         # (DiLoCo restore needs access to swarm_components)
+         # NOTE: Don't overwrite base_node.swarm - that's the DataSwarm for P2P downloads!
+         base_node.swarm_components = self.swarm_components
+
+         # Restore pending state from checkpoint (DiLoCo, optimizer)
+         # This must happen AFTER swarm components are initialized AND base_node.swarm is set
+         if hasattr(base_node, '_restore_pending_state'):
+             base_node._restore_pending_state()
+
+     # ==================== PROPERTIES ====================
+     # Expose training state for runner/dashboard access
+
+     @property
+     def total_training_rounds(self) -> int:
+         """Total training rounds completed (used by runner for PoNW proofs)."""
+         return self._total_training_rounds
+
+     @total_training_rounds.setter
+     def total_training_rounds(self, value: int):
+         self._total_training_rounds = value
+
+     @property
+     def current_loss(self) -> float:
+         """Current training loss (used by dashboard)."""
+         return self._current_loss
+
+     @current_loss.setter
+     def current_loss(self, value: float):
+         self._current_loss = value
+
+     @property
+     def total_tokens_processed(self) -> int:
+         """Total tokens processed (delegate to base node)."""
+         return self.base_node.total_tokens_processed
+
+     @total_tokens_processed.setter
+     def total_tokens_processed(self, value: int):
+         self.base_node.total_tokens_processed = value
+
+     @property
+     def current_training_round(self) -> int:
+         """Current DiLoCo outer round (for gradient sync coordination)."""
+         if self.swarm_components.diloco_trainer:
+             return self.swarm_components.diloco_trainer.stats.outer_step_count
+         return 0
+
+     def _init_swarm_components(self):
+         """Initialize all swarm components."""
+         from neuroshard.core.swarm.buffers import ActivationBuffer, OutboundBuffer
+         from neuroshard.core.swarm.router import SwarmRouter
+         from neuroshard.core.swarm.heartbeat import SwarmHeartbeatService
+         from neuroshard.core.swarm.compute import ComputeEngine
+
+         # Create buffers
+         self.swarm_components.inbound_buffer = ActivationBuffer(
+             max_size=self.config.inbound_buffer_size
+         )
+         self.swarm_components.outbound_buffer = OutboundBuffer(
+             max_size=self.config.outbound_buffer_size,
+             soft_overflow_threshold=self.config.soft_overflow_threshold,
+             hard_overflow_threshold=self.config.hard_overflow_threshold,
+         )
+
+         # Create router
+         dht = self.p2p_manager.dht if self.p2p_manager else None
+         self.swarm_components.swarm_router = SwarmRouter(dht_protocol=dht)
+         self.swarm_components.swarm_router.K_CANDIDATES = self.config.k_candidates
+         self.swarm_components.swarm_router.ACK_TIMEOUT_MS = self.config.ack_timeout_ms
+
+         # Create heartbeat service
+         self.swarm_components.heartbeat_service = SwarmHeartbeatService(
+             node_id=self.node_id,
+             udp_port=self.config.heartbeat_port,
+         )
+         self.swarm_components.heartbeat_service.HEARTBEAT_INTERVAL = self.config.heartbeat_interval
+         self.swarm_components.heartbeat_service.set_capacity_callback(self._get_capacity_bitmask)
+         self.swarm_components.heartbeat_service.set_peer_update_callback(
+             self.swarm_components.swarm_router.update_peer_from_heartbeat
+         )
+
+         # Create compute engine
+         self.swarm_components.compute_engine = ComputeEngine(
+             model=self.model,
+             inbound=self.swarm_components.inbound_buffer,
+             outbound=self.swarm_components.outbound_buffer,
+             diloco_trainer=None,  # Set after DiLoCo init
+             num_micro_batches=self.config.num_micro_batches,
+         )
+
+         # Initialize training components if training is enabled
+         if self.enable_training:
+             self._init_diloco()
+             self._init_checkpointer()
+             self._init_global_tracker()  # Initialize tracker at startup to restore state
+
+     def _init_diloco(self):
+         """Initialize DiLoCo trainer and related components."""
+         from neuroshard.core.swarm.diloco import DiLoCoTrainer, OuterOptimizer, DiLoCoConfig
+         from neuroshard.core.swarm.aggregation import (
+             RobustAggregator,
+             GradientValidator,
+             AggregationConfig,
+             AggregationStrategy,
+             ValidationConfig,
+         )
+
+         # Create outer optimizer
+         self.swarm_components.outer_optimizer = OuterOptimizer(
+             lr=self.config.diloco_outer_lr,
+             momentum=self.config.diloco_outer_momentum,
+         )
+
+         # Create DiLoCo trainer
+         diloco_config = DiLoCoConfig(
+             inner_steps=self.config.diloco_inner_steps,
+             outer_lr=self.config.diloco_outer_lr,
+             outer_momentum=self.config.diloco_outer_momentum,
+         )
+
+         self.swarm_components.diloco_trainer = DiLoCoTrainer(
+             model=self.model,
+             config=diloco_config,
+             inner_optimizer=self.base_node.optimizer,
+         )
+
+         # Connect to compute engine
+         self.swarm_components.compute_engine.diloco = self.swarm_components.diloco_trainer
+
+         # Create gradient validator
+         validation_config = ValidationConfig(
+             min_cosine_similarity=self.config.cosine_threshold,
+             max_magnitude_ratio=self.config.magnitude_ratio_threshold,
+         )
+         self.swarm_components.gradient_validator = GradientValidator(config=validation_config)
+
+         # Create robust aggregator
+         strategy_map = {
+             "mean": AggregationStrategy.MEAN,
+             "median": AggregationStrategy.MEDIAN,
+             "trimmed_mean": AggregationStrategy.TRIMMED_MEAN,
+             "krum": AggregationStrategy.KRUM,
+         }
+         strategy = strategy_map.get(self.config.aggregation_method, AggregationStrategy.TRIMMED_MEAN)
+
+         agg_config = AggregationConfig(
+             strategy=strategy,
+             num_byzantine=self.config.krum_f,
+             trim_fraction=self.config.trimmed_mean_beta,
+         )
+
+         self.swarm_components.robust_aggregator = RobustAggregator(
+             aggregation_config=agg_config,
+             validation_config=validation_config,
+         )
+
+     def _init_checkpointer(self):
+         """Initialize speculative checkpointer."""
+         from neuroshard.core.swarm.checkpoint import SpeculativeCheckpointer, CheckpointConfig
+
+         checkpoint_config = CheckpointConfig(
+             checkpoint_dir=str(self.config.get_checkpoint_dir()),
+             snapshot_interval=float(self.config.checkpoint_interval),
+             max_hot_snapshots=self.config.max_checkpoints,
+         )
+
+         self.swarm_components.speculative_checkpointer = SpeculativeCheckpointer(
+             model=self.model,
+             optimizer=self.base_node.optimizer,
+             config=checkpoint_config,
+             diloco_trainer=self.swarm_components.diloco_trainer,
+             p2p_manager=self.p2p_manager,
+         )
+
+     def _get_model_hash(self) -> str:
+         """
+         Get hash of current model architecture for compatibility checking.
+
+         Used to ensure peers are training compatible architectures before
+         accepting their gradient contributions.
+         """
+         if not self.model:
+             return ""
+
+         import hashlib
+         hasher = hashlib.sha256()
+
+         # Hash architecture dimensions (hidden_dim, num_layers, num_heads)
+         # Support both DynamicNeuroLLM (uses .architecture) and regular models (direct attributes)
+         if hasattr(self.model, 'architecture'):
+             hidden_dim = self.model.architecture.hidden_dim
+             num_heads = self.model.architecture.num_heads
+         else:
+             hidden_dim = getattr(self.model, 'hidden_dim', 0)
+             num_heads = getattr(self.model, 'num_heads', 0)
+         arch_str = f"{hidden_dim}:{len(self.my_layer_ids)}:{num_heads}"
+         hasher.update(arch_str.encode())
+
+         # Hash parameter names and shapes (not values - just structure)
+         for name, param in sorted(self.model.named_parameters()):
+             hasher.update(f"{name}:{list(param.shape)}".encode())
+
+         return hasher.hexdigest()[:16]
+
+     def _get_capacity_bitmask(self):
+         """Get current capacity for heartbeat broadcast."""
+         from neuroshard.core.swarm.heartbeat import CapacityBitmask
+
+         # Get memory info
+         available_mb = 0
+         gpu_util = 0.0
+
+         if torch.cuda.is_available():
+             try:
+                 free, total = torch.cuda.mem_get_info()
+                 available_mb = free // (1024 * 1024)
+             except Exception:
+                 pass
+         else:
+             available_mb = self.available_memory_mb
+
+         # Determine layer range
+         layer_range = (0, 0)
+         if self.my_layer_ids:
+             layer_range = (min(self.my_layer_ids), max(self.my_layer_ids) + 1)
+
+         # Get buffer status
+         queue_depth = len(self.swarm_components.inbound_buffer) if self.swarm_components.inbound_buffer else 0
+         is_backpressured = self.swarm_components.inbound_buffer.is_backpressured if self.swarm_components.inbound_buffer else False
+
+         return CapacityBitmask(
+             node_id=self.node_id,
+             timestamp=time.time(),
+             available_memory_mb=available_mb,
+             queue_depth=queue_depth,
+             layer_range=layer_range,
+             gpu_utilization=gpu_util,
+             network_saturation=0.0,
+             is_training=self.enable_training,
+             is_accepting_inference=True,
+             is_accepting_activations=not is_backpressured,
+             grpc_addr=self.base_node.grpc_addr,
+         )
+
+     # ==================== LIFECYCLE ====================
+
+     def start(self):
+         """Start all swarm components."""
+         self.swarm_components.start_sync()
+         logger.info("[SWARM] Node started")
+
+     def stop(self):
+         """Stop all swarm components."""
+         self.swarm_components.stop_sync()
+         logger.info("[SWARM] Node stopped")
+
+     async def start_async(self):
+         """Start async swarm components."""
+         await self.swarm_components.start_async()
+
+     async def stop_async(self):
+         """Stop async swarm components."""
+         await self.swarm_components.stop_async()
+
+     # ==================== BUFFER ACCESS ====================
+
+     def receive_activation(self, packet: 'ActivationPacket') -> bool:
+         """
+         Receive an activation packet from a peer.
+
+         Called by gRPC handler when SwarmForward is received.
+         """
+         return self.swarm_components.inbound_buffer.put_nowait(packet)
+
+     def get_buffer_status(self) -> Dict[str, Any]:
+         """Get status of inbound and outbound buffers."""
+         return {
+             'inbound': self.swarm_components.inbound_buffer.get_stats(),
+             'outbound': self.swarm_components.outbound_buffer.get_stats(),
+         }
+
+     # ==================== ROUTING ====================
+
+     def get_swarm_route(self) -> Dict[int, List['PeerCandidate']]:
+         """
+         Get swarm route with K candidates per layer.
+
+         Returns dict of layer_id -> list of K candidates.
+         """
+         from neuroshard.core.swarm.router import PeerCandidate
+
+         route: Dict[int, List[PeerCandidate]] = {}
+         num_layers = self.layer_pool.current_num_layers if self.layer_pool else 12
+
+         for layer_id in range(num_layers):
+             candidates = self.swarm_components.swarm_router.get_candidates(layer_id)
+             if candidates:
+                 route[layer_id] = candidates
+
+         return route
+
+     # ==================== TRAINING ====================
+
+     def _init_global_tracker(self):
+         """Initialize global training tracker for verification."""
+         try:
+             from neuroshard.core.training.global_tracker import GlobalTrainingTracker
+             self._global_tracker = GlobalTrainingTracker(
+                 node_id=self.node_id,
+                 model=self.model,
+             )
+             logger.info("[SWARM] Global training tracker initialized")
+         except Exception as e:
+             logger.warning(f"[SWARM] Could not init global tracker: {e}")
+             self._global_tracker = None
+
+     def train_step(self) -> Optional[float]:
+         """
+         Execute a training step with DiLoCo lazy gradient sync.
+
+         Supports both local and distributed training:
+         - Full nodes (embedding + LM head): Train locally with DiLoCo
+         - Partial nodes (embedding only): Use pipeline training to Validator
+         - Workers (no embedding): Wait for activations via gRPC
+
+         Returns loss value or None if no data available.
+         """
+         if not self.enable_training:
+             return None
+
+         # Ensure global tracker is initialized
+         if not hasattr(self, '_global_tracker') or self._global_tracker is None:
+             self._init_global_tracker()
+
+         # Check if we can do LOCAL training (need embedding + LM head)
+         # DYNAMIC CHECK: Use layer_pool to get CURRENT lm_head_holder
+         # This handles the case where a new Validator joined and took over the LM head
+         am_current_validator = self.model.has_lm_head
+         if hasattr(self.base_node, 'layer_pool') and self.base_node.layer_pool:
+             am_current_validator = (self.base_node.layer_pool.lm_head_holder == self.base_node.node_id)
+
+         is_full_node = self.model.has_embedding and am_current_validator
+
+         if not self.model.has_embedding:
+             # WORKER: No embedding = wait for activations via gRPC
+             # Training happens reactively in forward_pipeline/backward_pipeline
+             return None
+
+         if not is_full_node:
+             # DISTRIBUTED TRAINING: Has embedding but no LM head
+             # OR: We WERE a full node but new Validator took over → use pipeline!
+             # Use pipeline training - forward to Validator, receive gradients back
+             return self.base_node.train_step()
+
+         # LOCAL TRAINING: Full node with embedding + LM head
+         diloco = self.swarm_components.diloco_trainer
+
+         # Get training data
+         batch = self.base_node._get_training_batch()
+         if batch is None:
+             return None
+
+         input_ids, labels = batch
+
+         # Forward pass
+         self.model.train()
+         outputs = self.model.forward_my_layers(
+             self.model.embed(input_ids.to(self.device))
+         )
+
+         # Compute loss locally
+         logits = self.model.compute_logits(outputs)
+         loss = torch.nn.functional.cross_entropy(
+             logits.view(-1, logits.size(-1)),
+             labels.view(-1).to(self.device)
+         )
+
+         # DiLoCo inner step (backward + optimizer step + zero_grad)
+         diloco.inner_step(loss)
+         self._current_loss = loss.item()
+         self._total_training_rounds += 1
+
+         # Record loss to data loader for plateau detection
+         # This enables adaptive shard rotation when the model stops learning
+         if hasattr(self.base_node, 'genesis_loader') and self.base_node.genesis_loader:
+             self.base_node.genesis_loader.record_loss(self._current_loss)
+
+         # Note: grad_norm is 0 after inner_step because zero_grad() was called
+         # To properly track gradient norm, we'd need to compute it inside inner_step
+         # before zero_grad(). For now, we use pseudo-gradient norm from DiLoCo stats.
+         grad_norm = diloco.stats.avg_pseudo_grad_norm if diloco.stats.avg_pseudo_grad_norm > 0 else 0.0
+
+         # Record step in global tracker (for verification)
+         if self._global_tracker:
+             current_shard = 0
+             if hasattr(self.base_node, 'genesis_loader') and self.base_node.genesis_loader:
+                 stats = self.base_node.genesis_loader.get_stats()
+                 current_shard = stats.get('current_shard_id', 0)
+
+             self._global_tracker.record_step(
+                 loss=self._current_loss,
+                 step=self._total_training_rounds,
+                 shard_id=current_shard,
+                 tokens_in_batch=input_ids.numel(),
+                 gradient_norm=grad_norm,
+                 inner_step=diloco.stats.inner_step_count,
+                 outer_step=diloco.stats.outer_step_count,
+             )
+
+         # Check if outer sync needed
+         if diloco.should_sync():
+             self._do_diloco_sync()
+
+         # NOTE: Checkpoint saving is handled by DiLoCo outer sync (every 500 steps)
+         # We removed the per-100-step checkpoint because:
+         # 1. It caused 70-80s delays on memory-constrained systems
+         # 2. DiLoCo outer sync already saves after each 500-step cycle
+         # 3. Worst case loss on crash: 500 steps (acceptable)
+
+         return self._current_loss
+
+     def get_global_training_status(self) -> Dict[str, Any]:
+         """
+         Get global training verification status.
+
+         Returns comprehensive stats showing:
+         - Whether training is actually improving the model
+         - Whether nodes are converging (same model hash)
+         - Network-wide loss metrics
+         """
+         if not hasattr(self, '_global_tracker') or not self._global_tracker:
+             return {
+                 "error": "Global tracker not initialized",
+                 "training_verified": False,
+             }
+
+         return self._global_tracker.get_global_status()
+
+     def _do_diloco_sync(self):
+         """
+         Execute DiLoCo outer synchronization with REAL gradient exchange.
+
+         This is the key step that makes distributed training ACTUALLY work:
+         1. Compute local pseudo-gradient (delta from initial weights)
+         2. Gossip to peers via GossipGradient RPC
+         3. Collect peer contributions
+         4. Aggregate using Byzantine-tolerant aggregator
+         5. Apply aggregated update to model
+
+         If no peers respond, falls back to local gradient only.
+         """
+         import zlib
+         import random
+         from protos import neuroshard_pb2, neuroshard_pb2_grpc
+         from neuroshard.core.network.connection_pool import get_channel
+         from urllib.parse import urlparse
+
+         diloco = self.swarm_components.diloco_trainer
+
+         # Step 1: Compute local pseudo-gradient
+         pseudo_grad = diloco.compute_pseudo_gradient()
+
+         # Compress gradients for transmission
+         from neuroshard.core.training.production import GradientCompressor
+         compressor = GradientCompressor()
+
+         compressed_grads = {}
+         for name, grad in pseudo_grad.items():
+             compressed_grads[name] = compressor.compress(grad, name)
+
+         # Step 2: Gossip to peers
+         peers_synced = 0
+         peer_contributions = []  # List of (peer_id, decompressed_grads, batch_size, loss)
+
+         if self.p2p_manager:
+             # Get list of peers
+             peers = list(self.p2p_manager.known_peers.keys())
+             routing_peers = []
+             if self.p2p_manager.routing_table:
+                 for n in self.p2p_manager.routing_table.get_all_nodes():
+                     routing_peers.append(f"http://{n.ip}:{n.port}")
+
+             peers.extend(routing_peers)
+             # Deduplicate
+             peers = list(set(peers))
+
+             logger.debug(f"[SWARM] Peer discovery: known_peers={len(self.p2p_manager.known_peers)}, "
+                          f"routing_table={len(routing_peers)}, total_unique={len(peers)}")
+
+             if peers:
+                 # DYNAMIC GOSSIP FANOUT for network scaling
+                 # Formula: 2 * sqrt(N) + 3, capped at 50
+                 # This ensures gossip reaches the whole network in O(log N) rounds
+                 # - 1-10 nodes: 5-9 peers (ensures full coverage)
+                 # - 100 nodes: ~23 peers (good redundancy)
+                 # - 1000 nodes: ~66 peers (capped at 50)
+                 # - 10000+ nodes: 50 peers (bandwidth limit)
+                 import math
+                 num_peers = len(peers)
+                 # Use 2*sqrt(N) for better convergence in large networks
+                 fanout = min(int(2 * math.sqrt(num_peers) + 3), 50)
+                 targets = random.sample(peers, min(num_peers, fanout))
+
+                 # Create request with model hash for architecture validation
+                 model_hash = self._get_model_hash()
+
+                 # Include architecture dimensions for debugging compatibility issues
+                 arch_info = ""
+                 if self.model:
+                     arch_info = f"{getattr(self.model, 'hidden_dim', '?')}x{len(self.my_layer_ids)}L"
+
+                 req = neuroshard_pb2.GossipGradientRequest(
+                     node_id=str(self.node_id),
+                     round_id=diloco.stats.outer_step_count,
+                     model_hash=model_hash,  # Critical for architecture validation
+                     timestamp=time.time(),
+                     batch_size=diloco.config.inner_steps,
+                     loss=diloco.stats.avg_inner_loss,
+                     layer_gradients=compressed_grads,
+                     signature=f"diloco_{diloco.stats.outer_step_count}_{self.node_id[:8]}_{arch_info}",
+                     ttl=3,
+                 )
+
+                 logger.info(f"[SWARM] DiLoCo sync: gossiping to {len(targets)} peers...")
+
+                 # Gossip to each peer and collect responses
+                 for target_url in targets:
+                     try:
+                         parsed = urlparse(target_url)
+                         ip = parsed.hostname
+                         port = (parsed.port or 80) + 1000  # gRPC port
+
+                         channel = get_channel(f"{ip}:{port}")
+                         stub = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)
+
+                         resp = stub.GossipGradient(req, timeout=10.0)
+
+                         if resp.accepted:
+                             peers_synced += 1
+                             logger.debug(f"[SWARM] Peer {ip}:{port} accepted gradient (round={resp.current_round})")
+                     except Exception as e:
+                         logger.debug(f"[SWARM] Gossip to {target_url} failed: {e}")
+
+         # Step 3: Collect contributions (our own + any received from peers)
+         # Note: Peer contributions are received asynchronously via GossipGradient RPC handler
+         # For now, we use our local gradient + any cached peer gradients
+         all_contributions = [
+             {
+                 "node_id": self.node_id,
+                 "gradients": pseudo_grad,
+                 "batch_size": diloco.config.inner_steps,
+                 "loss": diloco.stats.avg_inner_loss,
+             }
+         ]
+
+         # Add any peer contributions we've received (stored in _received_peer_grads)
+         if hasattr(self, '_received_peer_grads'):
+             for peer_id, peer_data in self._received_peer_grads.items():
+                 # Only use fresh contributions (< 60s old)
+                 if time.time() - peer_data.get('timestamp', 0) < 60:
+                     all_contributions.append({
+                         "node_id": peer_id,
+                         "gradients": peer_data['gradients'],
+                         "batch_size": peer_data.get('batch_size', 1),
+                         "loss": peer_data.get('loss', float('inf')),
+                     })
+             # Clear after using
+             self._received_peer_grads = {}
+
+         # Step 4: Aggregate using robust aggregator
+         aggregated_grads = pseudo_grad  # Default to local if no peers
+
+         if len(all_contributions) > 1 and self.swarm_components.robust_aggregator:
+             logger.info(f"[SWARM] Aggregating {len(all_contributions)} contributions")
+             try:
+                 # Clear previous contributions and add new ones
+                 self.swarm_components.robust_aggregator.clear()
+
+                 # Add peer contributions (validate against our local gradient)
+                 rejected_peers = []
+                 for contrib in all_contributions:
+                     if contrib["node_id"] == self.node_id:
+                         continue  # Skip self - add as local_grads
+
+                     accepted, reason = self.swarm_components.robust_aggregator.add_contribution(
+                         peer_id=contrib["node_id"],
+                         gradients=contrib["gradients"],
+                         reference_grads=pseudo_grad,  # Validate against our gradient
+                         trust_score=1.0,
+                         validate=True
+                     )
+                     if not accepted:
+                         rejected_peers.append((contrib["node_id"][:16], reason))
+
+                 # Aggregate with our local gradients included
+                 aggregated_grads = self.swarm_components.robust_aggregator.aggregate(
+                     local_grads=pseudo_grad
+                 )
+
+                 if rejected_peers:
+                     logger.warning(f"[SWARM] Rejected {len(rejected_peers)} contributions: {rejected_peers}")
+
+             except Exception as e:
+                 logger.error(f"[SWARM] Aggregation failed, using local: {e}")
+                 import traceback
+                 logger.error(traceback.format_exc())
+                 aggregated_grads = pseudo_grad
+
+         # Capture stats BEFORE apply_outer_update (which resets them)
+         avg_loss_before_reset = diloco.stats.avg_inner_loss
+         outer_step_before = diloco.stats.outer_step_count
+
+         # Step 5: Apply aggregated update
+         diloco.apply_outer_update(aggregated_grads)
+
+         # Record sync result for tracking
+         if hasattr(self, '_global_tracker') and self._global_tracker:
+             self._global_tracker.record_sync_result(
+                 success=peers_synced > 0,
+                 peers_synced=peers_synced
+             )
+
+         sync_status = "distributed" if peers_synced > 0 else "local-only"
+         logger.info(
+             f"[SWARM] DiLoCo outer sync #{outer_step_before + 1} ({sync_status}): "
+             f"peers={peers_synced}, contributions={len(all_contributions)}, "
+             f"avg_loss={avg_loss_before_reset:.4f}"
+         )
+
+         # IMPORTANT: Save checkpoint after each outer sync (this is a major milestone!)
+         self._save_checkpoint()
+         logger.info(f"[SWARM] Checkpoint saved after outer sync #{diloco.stats.outer_step_count}")
+
+     def receive_peer_gradients(self, contribution) -> bool:
+         """
+         Receive gradient contribution from a peer (called by GossipGradient RPC).
+
+         Stores the contribution for use in next sync round.
+
+         Security Checks (in order):
+         1. Model hash validation (architecture compatibility)
+         2. Round ID check (freshness)
+         3. Gradient shape validation (per-parameter)
+         4. Gradient magnitude sanity check
+         """
+         from neuroshard.core.training.production import GradientCompressor
+
+         if not hasattr(self, '_received_peer_grads'):
+             self._received_peer_grads = {}
+
+         try:
+             # === SECURITY CHECK 1: Model Hash Validation (REQUIRED for gradients) ===
+             # Unlike proofs (which are historical records), gradient sync REQUIRES
+             # matching architectures - you cannot average tensors of different shapes.
+             # Peers with different architectures form separate training cohorts.
+             if hasattr(contribution, 'model_hash') and contribution.model_hash:
+                 our_hash = self._get_model_hash()
+                 if our_hash and contribution.model_hash != our_hash:
+                     logger.info(
+                         f"[SWARM] Skipping gradient from {contribution.node_id[:8]}... - "
+                         f"different architecture cohort (peer={contribution.model_hash[:8]}..., "
+                         f"ours={our_hash[:8]}...). Will sync with matching peers."
+                     )
+                     return False
+
+             # === SECURITY CHECK 2: Round ID Freshness ===
+             # Only accept gradients from recent rounds (within 2 rounds)
+             if self.swarm_components.diloco_trainer:
+                 our_round = self.swarm_components.diloco_trainer.stats.outer_step_count
+                 peer_round = getattr(contribution, 'round_id', 0)
+                 if abs(our_round - peer_round) > 2:
+                     logger.warning(
+                         f"[SWARM] REJECTED stale gradient from {contribution.node_id[:8]}... - "
+                         f"round mismatch: peer={peer_round}, ours={our_round}"
+                     )
+                     return False
+
+             # Decompress gradients
+             compressor = GradientCompressor()
+             gradients = {}
+
+             for name, compressed in contribution.layer_gradients.items():
+                 gradients[name] = compressor.decompress(compressed)
+
+             # === SECURITY CHECK 3: Architecture Shape Validation ===
+             # Ensures gradient shapes match our model parameters
+             if self.model and hasattr(self.model, 'named_parameters'):
+                 our_params = dict(self.model.named_parameters())
+                 mismatched = []
+
+                 for name, grad in gradients.items():
+                     if name in our_params:
+                         if grad.shape != our_params[name].shape:
+                             mismatched.append((name, grad.shape, our_params[name].shape))
+
+                 if mismatched:
+                     logger.warning(
+                         f"[SWARM] REJECTED gradient from {contribution.node_id[:8]}... - "
+                         f"{len(mismatched)} shape mismatches (incompatible architecture!)"
+                     )
+                     for name, peer_shape, our_shape in mismatched[:3]:  # Log first 3
+                         logger.warning(f"  {name}: peer={peer_shape}, ours={our_shape}")
+                     return False
+
+             # === SECURITY CHECK 4: Gradient Magnitude Sanity ===
+             # Reject obviously malicious gradients (too large)
+             MAX_GRAD_NORM = 1000.0  # Reasonable upper bound
+             total_norm = sum(g.norm().item() ** 2 for g in gradients.values()) ** 0.5
+             if total_norm > MAX_GRAD_NORM:
+                 logger.warning(
+                     f"[SWARM] REJECTED gradient from {contribution.node_id[:8]}... - "
+                     f"gradient norm {total_norm:.2f} exceeds max {MAX_GRAD_NORM}"
+                 )
+                 return False
+
+             # Store for aggregation
+             self._received_peer_grads[contribution.node_id] = {
+                 'gradients': gradients,
+                 'batch_size': getattr(contribution, 'batch_size', 1),
+                 'loss': getattr(contribution, 'loss', float('inf')),
+                 'timestamp': getattr(contribution, 'timestamp', time.time()),
+             }
+
+             logger.info(f"[SWARM] Accepted gradient from {contribution.node_id[:8]}... "
+                         f"(round={getattr(contribution, 'round_id', '?')}, "
+                         f"loss={getattr(contribution, 'loss', 0):.4f}, "
+                         f"norm={total_norm:.2f})")
+
+             return True
+
+         except Exception as e:
+             logger.error(f"[SWARM] Failed to process peer gradient: {e}")
+             import traceback
+             logger.error(traceback.format_exc())
+             return False
+
+     # ==================== STATS ====================
+
+     def get_stats(self) -> Dict[str, Any]:
+         """Get combined stats from base node and swarm components."""
+         # Safety check for shutdown race condition
+         base_node = getattr(self, 'base_node', None)
+         if not base_node:
+             return {"status": "shutting_down"}
+
+         try:
+             stats = base_node.get_stats()
+         except (AttributeError, RuntimeError):
+             # Node is shutting down
+             return {"status": "shutting_down"}
+
+         # Override with swarm node's actual training values (these are updated
+         # in train_step() but base_node's values are only synced on checkpoint)
+         stats["total_training_rounds"] = getattr(self, '_total_training_rounds', 0)
+         stats["current_loss"] = getattr(self, '_current_loss', 0.0)
+
+         swarm = getattr(self, 'swarm', None)
+         if swarm:
+             stats["swarm"] = swarm.get_stats()
+         return stats
+
+     def get_swarm_status(self) -> Dict[str, Any]:
+         """Get detailed swarm status."""
+         return self.swarm_components.get_stats()
+
+     def get_diloco_progress(self) -> Dict[str, Any]:
+         """Get DiLoCo training progress."""
+         if not self.swarm_components.diloco_trainer:
+             return {"enabled": False}
+
+         diloco = self.swarm_components.diloco_trainer
+         return {
+             "enabled": True,
+             "inner_step_count": diloco.stats.inner_step_count,
+             "inner_steps_total": diloco.config.inner_steps,
+             "progress": diloco.stats.inner_step_count / diloco.config.inner_steps,
+             "outer_step_count": diloco.stats.outer_step_count,
+         }
+
+     # ==================== DELEGATION ====================
+
+     @property
+     def grpc_addr(self):
+         """Get gRPC address."""
+         return self.base_node.grpc_addr
+
+     @property
+     def data_manager(self):
+         """Get data manager."""
+         return self.base_node.data_manager
+
+     @property
+     def genesis_loader(self):
+         """Get genesis loader."""
+         return self.base_node.genesis_loader
+
+     @property
+     def optimizer(self):
+         """Get optimizer."""
+         return self.base_node.optimizer
+
+     def forward(self, input_ids: torch.Tensor, **kwargs):
+         """Forward pass - delegates to base node."""
+         return self.base_node.forward(input_ids, **kwargs)
+
+     def _save_checkpoint(self, async_save: bool = True):
+         """
+         Save checkpoint with swarm state.
+
+         Syncs the swarm node's training counters to base node before saving.
+
+         Args:
+             async_save: If True, save in background thread (default).
+                         If False, block until save completes (for shutdown).
+         """
+         # Sync training state to base node before saving
+         self.base_node.total_training_rounds = self._total_training_rounds
+         self.base_node.current_loss = self._current_loss
+
+         # Delegate to base node's checkpoint saving
+         return self.base_node._save_checkpoint(async_save=async_save)
+
+     def __getattr__(self, name):
+         """Delegate unknown attributes to base node."""
+         return getattr(self.base_node, name)
+
+     def verify_training_work(self, proof: Any) -> Tuple[bool, str]:
+         """
+         Verify training work against THIS NODE's internal state.
+
+         Called by ProofVerifier after universal checks pass.
+
+         This method enforces "personal integrity" - checks that only
+         the node itself can verify:
+         - Did I actually do the work I'm claiming?
+         - Does the model hash match MY architecture?
+         - Is training actually enabled on MY node?
+
+         Universal checks (rate limits, required fields) are handled
+         by ProofVerifier before this method is called.
+
+         Args:
+             proof: PoNWProof object
+
+         Returns:
+             (is_valid, reason) tuple
+         """
+         # =====================================================================
+         # NODE-SPECIFIC CHECK 1: Model Hash Recording (NOT rejection)
+         # =====================================================================
+         # The model hash in a proof records WHAT architecture the work was done on.
+         # In a growing network, architectures evolve - a proof from a month ago
+         # was valid for THAT architecture, even if the network has since grown.
+         #
+         # We LOG mismatches for monitoring but DO NOT REJECT - the work was valid
+         # when it was done. The hash is stored with the proof as historical metadata.
+         current_hash = self._get_model_hash()
+         proof_hash = getattr(proof, 'model_hash', None)
+
+         if proof_hash and current_hash and proof_hash != current_hash:
+             # Different architecture - this is NORMAL in a growing network.
+             # The proof was valid for its architecture at the time.
+             logger.info(
+                 f"[PoNW] Proof from different architecture epoch: "
+                 f"proof={proof_hash[:8]}..., current={current_hash[:8]}... "
+                 f"(network has evolved - this is expected)"
+             )
+             # Continue validation - don't reject based on hash alone
+
+         # =====================================================================
+         # NODE-SPECIFIC CHECK 2: Internal Counter Verification
+         # =====================================================================
+         # This is the core "Lazy Mining" prevention.
+         # Only verify our OWN proofs against internal counter - we can't
+         # know other nodes' internal state.
+         proof_node_id = getattr(proof, 'node_id', None)
+         claimed_batches = getattr(proof, 'training_batches', 0)
+
+         if proof_node_id == self.node_id:
+             # This is OUR proof - verify against our actual work count.
+             # Allow a small buffer (10) for timing/sync issues.
+             buffer = 10
+             if claimed_batches > (self._total_training_rounds + buffer):
+                 logger.warning(
+                     f"[PoNW] Lazy mining attempt blocked: "
+                     f"claimed={claimed_batches}, actual={self._total_training_rounds}"
+                 )
+                 return False, (
+                     f"Claimed batches ({claimed_batches}) exceeds "
+                     f"internal counter ({self._total_training_rounds})"
+                 )
+
+         # =====================================================================
+         # NODE-SPECIFIC CHECK 3: Training Enabled Status
+         # =====================================================================
+         # Can't claim training rewards if training is disabled.
+         if not self.enable_training and claimed_batches > 0:
+             return False, "Claimed training work but training is disabled"
+
+         # All node-specific checks passed
+         return True, "Training work verified against local state"
+
+
+ # ==================== FACTORY ====================
+
+ def create_swarm_node(
+     node_token: str,
+     port: int,
+     tracker_url: str,
+     config: SwarmNodeConfig,
+     available_memory_mb: Optional[int] = None,
+     enable_training: bool = False,
+     max_storage_mb: int = 10000,
+     max_cpu_threads: int = 4,
+     p2p_manager: Optional[Any] = None,
+     device: str = "auto",
+ ) -> SwarmEnabledDynamicNode:
+     """
+     Factory function to create a SwarmEnabledDynamicNode.
+
+     This is the main entry point for creating nodes.
+
+     Args:
+         node_token: Authentication token for the node
+         port: HTTP port for the node
+         tracker_url: URL of the tracker server
+         config: SwarmNodeConfig with settings (REQUIRED)
+         available_memory_mb: Override memory detection
+         enable_training: Whether to enable training
+         max_storage_mb: Max disk space for data shards
+         max_cpu_threads: Max CPU threads to use
+         p2p_manager: Optional P2P manager (created if not provided)
+         device: Device selection (default "auto")
+
+     Returns:
+         SwarmEnabledDynamicNode ready for use
+     """
+     # Create base DynamicNeuroNode
+     # Pass P2P manager so DHT is available during layer assignment!
+     base_node = create_dynamic_node(
+         node_token=node_token,
+         port=port,
+         tracker_url=tracker_url,
+         available_memory_mb=available_memory_mb,
+         enable_training=enable_training,
+         max_storage_mb=max_storage_mb,
+         max_cpu_threads=max_cpu_threads,
+         device=device,
+         p2p_manager=p2p_manager,  # NEW: Pass P2P for DHT discovery
+     )
+
+     # Wrap with swarm capabilities
+     swarm_node = SwarmEnabledDynamicNode(
+         base_node=base_node,
+         config=config,
+         p2p_manager=p2p_manager,
+     )
+
+     # INJECT INTO LEDGER: Connect the node to the ledger for PoNW verification
+     # This fixes the "Lazy Mining" vulnerability by allowing the ledger to
+     # verify work against the actual model state.
+     if base_node.p2p_manager and base_node.p2p_manager.ledger:
+         # Set the model interface for the verifier
+         if hasattr(base_node.p2p_manager.ledger, 'verifier'):
+             base_node.p2p_manager.ledger.verifier.model_interface = swarm_node
+             logger.info("[FACTORY] Connected SwarmNode to Ledger Verifier (PoNW security active)")
+
+     return swarm_node
+
+
+ # Alias for backwards compatibility and clarity
+ create_swarm_node_with_p2p = create_swarm_node
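
For reference, a minimal usage sketch of the factory shown in this file, adapted from the module docstring; the token, port, and tracker URL below are placeholders, not values shipped with the package:

```python
# Sketch adapted from the module docstring of neuroshard/core/swarm/factory.py.
# Token, port, and tracker URL are placeholders.
from neuroshard.core.swarm import create_swarm_node, SwarmNodeConfig

config = SwarmNodeConfig(
    diloco_inner_steps=500,       # base value; auto-scaled with network size
    checkpoint_interval=120,      # seconds between speculative snapshots
    aggregation_method="trimmed_mean",
)

swarm_node = create_swarm_node(
    node_token="YOUR_NODE_TOKEN",                # placeholder
    port=8000,                                   # placeholder HTTP port
    tracker_url="http://tracker.example:9000",   # placeholder
    config=config,
    enable_training=True,
)

swarm_node.start()  # starts synchronous components (heartbeat, checkpointer)
# In an event loop: await swarm_node.start_async()  # router + compute engine
```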
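
And a small standalone sketch of the two network-size formulas that appear in the code above (inner-step scaling in `SwarmNodeConfig.get_diloco_inner_steps`, gossip fanout in `_do_diloco_sync`); the helper names here are illustrative only:

```python
# Standalone sketch of the two scaling formulas used in factory.py.
import math

def diloco_inner_steps(num_peers: int, base: int = 500, lo: int = 50, hi: int = 1000) -> int:
    # base * (1 + log10(N) / 3), clamped to [lo, hi]
    if num_peers <= 1:
        return base
    scaled = int(base * (1.0 + math.log10(num_peers) / 3))
    return max(lo, min(hi, scaled))

def gossip_fanout(num_peers: int) -> int:
    # 2 * sqrt(N) + 3, capped at 50
    return min(int(2 * math.sqrt(num_peers) + 3), 50)

for n in (1, 10, 100, 1000):
    print(n, diloco_inner_steps(n), gossip_fanout(n))
# 1    -> 500 steps, fanout 5
# 10   -> 666 steps, fanout 9   (the docstring's "665" uses the rounded factor 1.33)
# 100  -> 833 steps, fanout 23
# 1000 -> 1000 steps, fanout 50 (66 before the cap)
```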