nexaroa 0.0.111__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (78)
  1. neuroshard/__init__.py +93 -0
  2. neuroshard/__main__.py +4 -0
  3. neuroshard/cli.py +466 -0
  4. neuroshard/core/__init__.py +92 -0
  5. neuroshard/core/consensus/verifier.py +252 -0
  6. neuroshard/core/crypto/__init__.py +20 -0
  7. neuroshard/core/crypto/ecdsa.py +392 -0
  8. neuroshard/core/economics/__init__.py +52 -0
  9. neuroshard/core/economics/constants.py +387 -0
  10. neuroshard/core/economics/ledger.py +2111 -0
  11. neuroshard/core/economics/market.py +975 -0
  12. neuroshard/core/economics/wallet.py +168 -0
  13. neuroshard/core/governance/__init__.py +74 -0
  14. neuroshard/core/governance/proposal.py +561 -0
  15. neuroshard/core/governance/registry.py +545 -0
  16. neuroshard/core/governance/versioning.py +332 -0
  17. neuroshard/core/governance/voting.py +453 -0
  18. neuroshard/core/model/__init__.py +30 -0
  19. neuroshard/core/model/dynamic.py +4186 -0
  20. neuroshard/core/model/llm.py +905 -0
  21. neuroshard/core/model/registry.py +164 -0
  22. neuroshard/core/model/scaler.py +387 -0
  23. neuroshard/core/model/tokenizer.py +568 -0
  24. neuroshard/core/network/__init__.py +56 -0
  25. neuroshard/core/network/connection_pool.py +72 -0
  26. neuroshard/core/network/dht.py +130 -0
  27. neuroshard/core/network/dht_plan.py +55 -0
  28. neuroshard/core/network/dht_proof_store.py +516 -0
  29. neuroshard/core/network/dht_protocol.py +261 -0
  30. neuroshard/core/network/dht_service.py +506 -0
  31. neuroshard/core/network/encrypted_channel.py +141 -0
  32. neuroshard/core/network/nat.py +201 -0
  33. neuroshard/core/network/nat_traversal.py +695 -0
  34. neuroshard/core/network/p2p.py +929 -0
  35. neuroshard/core/network/p2p_data.py +150 -0
  36. neuroshard/core/swarm/__init__.py +106 -0
  37. neuroshard/core/swarm/aggregation.py +729 -0
  38. neuroshard/core/swarm/buffers.py +643 -0
  39. neuroshard/core/swarm/checkpoint.py +709 -0
  40. neuroshard/core/swarm/compute.py +624 -0
  41. neuroshard/core/swarm/diloco.py +844 -0
  42. neuroshard/core/swarm/factory.py +1288 -0
  43. neuroshard/core/swarm/heartbeat.py +669 -0
  44. neuroshard/core/swarm/logger.py +487 -0
  45. neuroshard/core/swarm/router.py +658 -0
  46. neuroshard/core/swarm/service.py +640 -0
  47. neuroshard/core/training/__init__.py +29 -0
  48. neuroshard/core/training/checkpoint.py +600 -0
  49. neuroshard/core/training/distributed.py +1602 -0
  50. neuroshard/core/training/global_tracker.py +617 -0
  51. neuroshard/core/training/production.py +276 -0
  52. neuroshard/governance_cli.py +729 -0
  53. neuroshard/grpc_server.py +895 -0
  54. neuroshard/runner.py +3223 -0
  55. neuroshard/sdk/__init__.py +92 -0
  56. neuroshard/sdk/client.py +990 -0
  57. neuroshard/sdk/errors.py +101 -0
  58. neuroshard/sdk/types.py +282 -0
  59. neuroshard/tracker/__init__.py +0 -0
  60. neuroshard/tracker/server.py +864 -0
  61. neuroshard/ui/__init__.py +0 -0
  62. neuroshard/ui/app.py +102 -0
  63. neuroshard/ui/templates/index.html +1052 -0
  64. neuroshard/utils/__init__.py +0 -0
  65. neuroshard/utils/autostart.py +81 -0
  66. neuroshard/utils/hardware.py +121 -0
  67. neuroshard/utils/serialization.py +90 -0
  68. neuroshard/version.py +1 -0
  69. nexaroa-0.0.111.dist-info/METADATA +283 -0
  70. nexaroa-0.0.111.dist-info/RECORD +78 -0
  71. nexaroa-0.0.111.dist-info/WHEEL +5 -0
  72. nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
  73. nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
  74. nexaroa-0.0.111.dist-info/top_level.txt +2 -0
  75. protos/__init__.py +0 -0
  76. protos/neuroshard.proto +651 -0
  77. protos/neuroshard_pb2.py +160 -0
  78. protos/neuroshard_pb2_grpc.py +1298 -0
@@ -0,0 +1,4186 @@
+ """
+ Dynamic Model Architecture - True Decentralization
+
+ This module implements a model that grows and shrinks with the network:
+ - NO fixed phases or model sizes
+ - Model size = what the network can collectively hold
+ - Each node contributes based on its available memory
+ - More memory = more layers = more NEURO rewards
+
+ Key Concepts:
+ 1. LAYER POOL: The network maintains a pool of layers
+ 2. DYNAMIC ASSIGNMENT: Nodes claim layers based on their capacity
+ 3. ORGANIC GROWTH: As more nodes join, model can have more layers
+ 4. GRACEFUL DEGRADATION: If nodes leave, layers are redistributed
+
+ Example:
+ Day 1: 10 nodes with 4GB each = 40GB total = ~10B params possible
+ Day 30: 100 nodes with avg 8GB = 800GB total = ~200B params possible
+
+ The model AUTOMATICALLY grows as capacity grows.
+ No voting, no phases, no central coordination.
+ """
+
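The capacity figures in the example above are plain fp32 weight arithmetic (4 bytes per parameter). A minimal sketch of that estimate, not part of the package and ignoring optimizer state and activations:

```python
def max_params_for_memory(total_memory_gb: float, bytes_per_param: int = 4) -> float:
    """Rough upper bound on parameters that fit in aggregate memory (weights only)."""
    return total_memory_gb * 1e9 / bytes_per_param

# Day 1:  10 nodes x 4 GB  -> ~10B params
# Day 30: 100 nodes x 8 GB -> ~200B params
print(f"{max_params_for_memory(10 * 4) / 1e9:.0f}B")   # 10B
print(f"{max_params_for_memory(100 * 8) / 1e9:.0f}B")  # 200B
```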
+ import torch
+ import torch.nn as nn
+ import threading
+ import time
+ import logging
+ import hashlib
+ import math
+ import psutil # For adaptive memory management
+ from typing import Optional, Dict, List, Tuple, Any, Set
+ from dataclasses import dataclass, field
+ from collections import defaultdict
+ from urllib.parse import urlparse
+
+ logger = logging.getLogger(__name__)
+
+
+ # Dynamic architecture - NO MORE FIXED DIMENSIONS!
+ # Architecture is now calculated based on network capacity
+
+ # Dynamic vocabulary - starts at 32K, expands WITHOUT LIMIT as network grows
+ # The embedding/lm_head grow in chunks when tokenizer vocabulary exceeds current capacity
+ INITIAL_VOCAB_SIZE = 32000 # Starting size (efficient for small networks)
+ VOCAB_GROWTH_CHUNK = 32000 # Expand by 32K at a time (efficient GPU memory alignment)
+
+ # NO HARD LIMIT - vocabulary grows with the network
+ # The only real constraints are:
+ # - Memory: ~4KB per token (at hidden_dim=1024) or ~16KB (at hidden_dim=4096)
+ # - Practical: Most use cases covered under 1M tokens
+ # For reference: GPT-4 ~100K, Claude ~100K, Gemini ~256K
+ # NeuroShard can grow FAR beyond these as a truly decentralized, ever-growing LLM
+ MAX_VOCAB_SIZE = None # None = unlimited (constrained only by available memory)
+
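The per-token figures and chunked growth above reduce to simple arithmetic. A short illustrative sketch (the actual expansion logic lives in DynamicNeuroLLM.expand_vocabulary further down; this only restates the math):

```python
VOCAB_GROWTH_CHUNK = 32000

def round_up_to_chunk(requested_vocab: int, chunk: int = VOCAB_GROWTH_CHUNK) -> int:
    """Round a requested vocab size up to the next growth chunk boundary."""
    return ((requested_vocab + chunk - 1) // chunk) * chunk

def embedding_row_kb(hidden_dim: int, bytes_per_weight: int = 4) -> float:
    """fp32 memory per vocabulary entry in one weight matrix (embedding or lm_head)."""
    return hidden_dim * bytes_per_weight / 1024

print(round_up_to_chunk(33050))  # 64000
print(embedding_row_kb(1024))    # 4.0  -> "~4KB per token"
print(embedding_row_kb(4096))    # 16.0 -> "~16KB per token"
```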
+ # Import the new architecture scaler
+ from neuroshard.core.model.scaler import (
+ ModelArchitecture,
+ calculate_optimal_architecture,
+ should_upgrade_architecture,
+ estimate_memory_per_layer,
+ calculate_layer_assignment,
+ )
+
+
+ @dataclass
+ class LayerAssignment:
+ """Assignment of a layer to a node."""
+ layer_id: int
+ node_id: str
+ node_url: str
+ grpc_addr: str
+ assigned_at: float = field(default_factory=time.time)
+ last_heartbeat: float = field(default_factory=time.time)
+ version: int = 0 # Training version
+
+
+ @dataclass
+ class NetworkCapacity:
+ """Current network capacity."""
+ total_nodes: int
+ total_memory_mb: float
+ max_layers: int # How many layers the network can support
+ assigned_layers: int # How many layers are currently assigned
+ layer_coverage: Dict[int, int] # layer_id -> replica count
+
+
+ class DynamicLayerPool:
+ """
+ Manages the dynamic pool of model layers across the network.
+
+ This is the core of true decentralization:
+ - Layers are assigned based on node capacity
+ - Model grows BOTH deeper AND wider as network expands
+ - Architecture auto-optimizes based on scaling laws
+ - No fixed model size or phases
+
+ SCALABILITY CONSIDERATIONS:
+ ==========================
+ Small network (1-10 nodes):
+ - Each node may hold ALL layers (solo training mode)
+ - No layer replication needed
+ - Fast startup, immediate training
+
+ Medium network (10-100 nodes):
+ - Layers distributed across multiple nodes
+ - MIN_REPLICAS ensures redundancy
+ - Pipeline inference works across nodes
+
+ Large network (100-1000+ nodes):
+ - Strong layer distribution
+ - MAX_LAYERS_PER_NODE caps per-node load
+ - Architecture can scale to 100B+ params
+ """
+
+ # Minimum replicas per layer for redundancy
+ MIN_REPLICAS = 2
+
+ # Layer heartbeat timeout
+ HEARTBEAT_TIMEOUT = 120 # seconds
+
+ # Maximum layers any single node can hold (prevents memory issues in large networks)
+ # In small networks (< 100 nodes), this is effectively unlimited
+ # In large networks, it ensures load is distributed
+ MAX_LAYERS_PER_NODE = 200
+
+ # Architecture recalculation triggers
+ # NOTE: RECALC_INTERVAL_NODES is now DYNAMIC - see _get_recalc_interval()
+ MIN_UPGRADE_IMPROVEMENT = 1.3 # Only upgrade if 30%+ better
+
+ @staticmethod
+ def _get_recalc_interval(node_count: int) -> int:
+ """
+ Get dynamic architecture recalculation interval based on network size.
+
+ At small networks, recalculate more often (every node matters).
+ At large networks, recalculate less often (stability).
+
+ Formula: min(max(5, node_count // 10), 100)
+ - 1-50 nodes: every 5 nodes
+ - 51-100 nodes: every 5-10 nodes
+ - 100-1000 nodes: every 10-100 nodes
+ - 1000+ nodes: every 100 nodes
+ """
+ return min(max(5, node_count // 10), 100)
+
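A quick check of the interval formula at a few network sizes, simply restating the expression above:

```python
def recalc_interval(node_count: int) -> int:
    # Same formula as DynamicLayerPool._get_recalc_interval
    return min(max(5, node_count // 10), 100)

for n in (1, 10, 50, 200, 1000, 5000):
    print(n, recalc_interval(n))
# 1 -> 5, 10 -> 5, 50 -> 5, 200 -> 20, 1000 -> 100, 5000 -> 100
```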
+ def __init__(self, dht_protocol=None):
148
+ self.dht = dht_protocol
149
+
150
+ # Layer assignments
151
+ self.layer_assignments: Dict[int, List[LayerAssignment]] = defaultdict(list)
152
+
153
+ # Node capacities
154
+ self.node_capacities: Dict[str, float] = {} # node_id -> available_mb
155
+
156
+ # DYNAMIC VOCAB: Track current vocabulary capacity for memory calculation
157
+ # This is updated when vocab expands and affects layer assignment
158
+ self.vocab_capacity: int = INITIAL_VOCAB_SIZE
159
+
160
+ # DYNAMIC ARCHITECTURE (auto-updates as network grows)
161
+ self.current_architecture: Optional[ModelArchitecture] = None
162
+ self.architecture_version: int = 0
163
+ self.last_node_count: int = 0
164
+
165
+ # Legacy fields (for compatibility)
166
+ self.current_num_layers = 0
167
+ self.embedding_holder: Optional[str] = None
168
+ self.lm_head_holder: Optional[str] = None
169
+
170
+ # Threading
171
+ self.lock = threading.Lock()
172
+
173
+ logger.info("DynamicLayerPool initialized with dynamic width + depth scaling")
174
+
175
+ def _auto_recalculate_architecture(self):
176
+ """
177
+ AUTOMATED architecture optimization - no human intervention needed.
178
+
179
+ Calculates optimal architecture based on current network capacity
180
+ and triggers upgrade if improvement is significant.
181
+ """
182
+ total_memory = sum(self.node_capacities.values())
183
+ optimal = calculate_optimal_architecture(total_memory)
184
+
185
+ if self.current_architecture is None:
186
+ # First initialization
187
+ self.current_architecture = optimal
188
+ self.current_num_layers = optimal.num_layers
189
+ self.architecture_version = 1
190
+ logger.info(f"🚀 Initial architecture: {optimal.num_layers}L × {optimal.hidden_dim}H "
191
+ f"({optimal.estimate_params()/1e6:.0f}M params)")
192
+ return
193
+
194
+ # Check if upgrade is worthwhile
195
+ should_upgrade, reason = should_upgrade_architecture(
196
+ self.current_architecture,
197
+ optimal,
198
+ self.MIN_UPGRADE_IMPROVEMENT
199
+ )
200
+
201
+ if should_upgrade:
202
+ logger.warning(f"🔄 ARCHITECTURE UPGRADE TRIGGERED!")
203
+ logger.warning(f" {reason}")
204
+ logger.warning(f" Old: {self.current_architecture.num_layers}L × {self.current_architecture.hidden_dim}H")
205
+ logger.warning(f" New: {optimal.num_layers}L × {optimal.hidden_dim}H")
206
+ logger.warning(f" Nodes will gradually migrate to new architecture on restart")
207
+
208
+ # Update architecture (new nodes will use new arch)
209
+ self.current_architecture = optimal
210
+ self.current_num_layers = optimal.num_layers
211
+ self.architecture_version += 1
212
+
213
+ # TODO: Trigger distillation-based migration for existing nodes
214
+ # For now, existing nodes keep their architecture until restart
215
+ else:
216
+ logger.debug(f"Architecture recalculation: no upgrade needed ({reason})")
217
+
218
+ def register_node(
219
+ self,
220
+ node_id: str,
221
+ node_url: str,
222
+ grpc_addr: str,
223
+ available_memory_mb: float,
224
+ staked_amount: float = 0.0
225
+ ) -> List[int]:
226
+ """
227
+ Register a node and assign layers based on its capacity AND stake.
228
+
229
+ AUTOMATIC ARCHITECTURE SCALING:
230
+ - Periodically recalculates optimal architecture as network grows
231
+ - Triggers upgrades when capacity increases significantly
232
+ - New nodes automatically use latest architecture
233
+
234
+ Validator role requires:
235
+ 1. Sufficient memory (>2GB)
236
+ 2. Minimum stake (100 NEURO) - prevents Sybil attacks
237
+
238
+ Returns list of layer IDs assigned to this node.
239
+ """
240
+ # Import validator requirements from centralized economics (with dynamic scaling!)
241
+ from neuroshard.core.economics.constants import (
242
+ VALIDATOR_MIN_MEMORY_MB,
243
+ get_dynamic_validator_stake
244
+ )
245
+
246
+ with self.lock:
247
+ self.node_capacities[node_id] = available_memory_mb
248
+
249
+ # AUTO-TRIGGER: Recalculate architecture if network grew significantly
250
+ node_count = len(self.node_capacities)
251
+ recalc_interval = self._get_recalc_interval(node_count)
252
+ if (node_count - self.last_node_count) >= recalc_interval:
253
+ self._auto_recalculate_architecture()
254
+ self.last_node_count = node_count
255
+
256
+ # Ensure we have an architecture
257
+ if self.current_architecture is None:
258
+ self._auto_recalculate_architecture()
259
+
260
+ # Calculate how many layers this node can hold
261
+ # Uses current architecture's dimensions (dynamic!)
262
+ # DEVICE-AWARE safety factors
263
+ # With gradient checkpointing always enabled, CPU can use higher factor
264
+ device_type = getattr(self, '_device_hint', 'cpu')
265
+ if device_type == 'cuda':
266
+ safety_factor = 0.6 # GPU: efficient memory usage
267
+ elif device_type == 'mps':
268
+ safety_factor = 0.5 # Apple Silicon: moderate overhead
269
+ else:
270
+ safety_factor = 0.5 # CPU: increased from 0.3 (checkpointing reduces overhead)
271
+
272
+ # SMART DISTRIBUTED TRAINING DESIGN:
273
+ # Goal: Enable meaningful training for ANY network size
274
+ #
275
+ # KEY INSIGHT: When a "full node" exists (has embedding + LM head),
276
+ # new nodes should NOT become Drivers (would create broken overlap).
277
+ # Instead, they should become WORKER+VALIDATOR to enable proper pipeline.
278
+ #
279
+ # Network States:
280
+ # 1. Empty network → First node becomes DRIVER, grows to full node
281
+ # 2. One full node exists → New node becomes WORKER+VALIDATOR (pipeline!)
282
+ # 3. Multiple partial nodes → Fill based on MIN_REPLICAS
283
+
284
+ # FULLY DECENTRALIZED: Discover network state from DHT ONLY (no tracker fallback!)
285
+ # Each node has its own LOCAL layer_pool, so we need DHT for network-wide view.
286
+ # BUT: DHT may have STALE data from previous runs - don't blindly trust it!
287
+ dht_layers = set()
288
+
289
+ # DHT discovery (P2P must be connected BEFORE start() for this to work!)
290
+ if self.dht:
291
+ dht_layers = self._discover_network_layers_from_dht()
292
+ if dht_layers:
293
+ highest_layer = max(dht_layers)
294
+ # SANITY CHECK: Only expand if DHT layers are "reasonable"
295
+ max_reasonable = max(32, self.current_num_layers * 2)
296
+ if highest_layer >= self.current_num_layers and highest_layer < max_reasonable:
297
+ self.current_num_layers = highest_layer + 1
298
+ logger.info(f"DHT discovery: network has {self.current_num_layers} layers")
299
+ elif highest_layer >= max_reasonable:
300
+ logger.warning(f"DHT shows {highest_layer + 1} layers but seems stale")
301
+ logger.warning(f"Ignoring stale DHT data - will use checkpoint layer count")
302
+ else:
303
+ logger.info("DHT: No existing layers found - this may be first node or peers not yet discovered")
304
+ else:
305
+ logger.warning("DHT not available - layer assignment will be solo mode")
306
+
307
+ driver_count = len(self.layer_assignments.get(0, []))
308
+ validator_layer = max(0, self.current_num_layers - 1) if self.current_num_layers > 0 else 0
309
+ validator_count = len(self.layer_assignments.get(validator_layer, []))
310
+
311
+ # CRITICAL: Check if a FULL NODE exists (FULLY DECENTRALIZED - DHT only!)
312
+ # A full node has BOTH layer 0 (embedding) AND last layer (LM head)
313
+ has_full_node = False
314
+ full_node_id = None
315
+
316
+ # Check local assignments first
317
+ if self.current_num_layers > 0:
318
+ layer_0_holders = {a.node_id for a in self.layer_assignments.get(0, [])}
319
+ last_layer_holders = {a.node_id for a in self.layer_assignments.get(validator_layer, [])}
320
+ full_node_ids = layer_0_holders & last_layer_holders # Intersection
321
+ if full_node_ids:
322
+ has_full_node = True
323
+ full_node_id = next(iter(full_node_ids))
324
+ logger.info(f"Full node detected (local): {full_node_id[:8]}... (has layers 0-{validator_layer})")
325
+
326
+ # Also check DHT - if both layer 0 AND last layer exist, a full node likely exists
327
+ if not has_full_node and dht_layers:
328
+ if 0 in dht_layers and validator_layer in dht_layers:
329
+ has_full_node = True
330
+ full_node_id = "unknown_from_dht"
331
+ logger.info(f"Full node detected (DHT): network has layers 0-{validator_layer}")
332
+
333
+ # ROLE ASSIGNMENT PRIORITY (redesigned for proper distributed training):
334
+ #
335
+ # If NO full node exists:
336
+ # 1. First node → DRIVER (will grow to full node)
337
+ # 2. Additional nodes → fill based on MIN_REPLICAS
338
+ #
339
+ # If a FULL NODE exists:
340
+ # - DON'T create another Driver (would overlap and break training!)
341
+ # - New nodes become WORKER+VALIDATOR for proper pipeline
342
+ # - This creates: FullNode[0-N] → NewNode[N+1 to Last] pipeline
343
+
344
+ if has_full_node:
345
+ # A full node exists - new nodes should NOT overlap with it
346
+ # Become WORKER+VALIDATOR to enable pipeline training!
347
+ if node_id == full_node_id:
348
+ # This IS the full node re-registering
349
+ role_hint = "DRIVER" # Keep as driver (already full)
350
+ needs_embedding = True
351
+ logger.info(f"This is the full node - keeping as DRIVER")
352
+ else:
353
+ # New node joining a network with a full node
354
+ # Become WORKER+VALIDATOR: get last layer + work backwards
355
+ # This enables pipeline: FullNode[embedding→layers] → Us[layers→LM_head]
356
+ role_hint = "WORKER+VALIDATOR"
357
+ needs_embedding = False # Don't need embedding, saves memory!
358
+ logger.info(f"Full node exists - becoming WORKER+VALIDATOR for pipeline training")
359
+ else:
360
+ # No full node - use standard role assignment
361
+ needs_more_drivers = driver_count < self.MIN_REPLICAS
362
+ needs_more_validators = validator_count < self.MIN_REPLICAS
363
+
364
+ if needs_more_drivers:
365
+ # Network needs more Drivers for MIN_REPLICAS
366
+ role_hint = "DRIVER"
367
+ needs_embedding = True
368
+ logger.info(f"Network needs more Drivers ({driver_count}/{self.MIN_REPLICAS})")
369
+ elif needs_more_validators and self.current_num_layers > 1:
370
+ # Network needs more Validators for MIN_REPLICAS
371
+ role_hint = "WORKER+VALIDATOR"
372
+ needs_embedding = False
373
+ logger.info(f"Network needs more Validators ({validator_count}/{self.MIN_REPLICAS})")
374
+ else:
375
+ # Network has enough of both → become Worker for layer redundancy
376
+ role_hint = "WORKER"
377
+ needs_embedding = False
378
+ logger.info(f"Network has enough Drivers ({driver_count}) and Validators ({validator_count})")
379
+
380
+ max_layers_for_node = calculate_layer_assignment(
381
+ available_memory_mb,
382
+ self.current_architecture,
383
+ safety_factor=safety_factor,
384
+ vocab_capacity=self.vocab_capacity,
385
+ training_mode=True,
386
+ needs_embedding=needs_embedding
387
+ )
388
+
389
+ logger.info(f"Layer calculation ({role_hint}): {available_memory_mb:.0f}MB × {safety_factor} safety = {max_layers_for_node} layers "
390
+ f"(needs_embedding={needs_embedding}, drivers={driver_count}, validators={validator_count})")
391
+
392
+ # SCALABILITY: Apply MAX_LAYERS_PER_NODE cap in large networks
393
+ # This prevents single nodes from hogging all layers and ensures
394
+ # load distribution as the network grows
395
+ node_count = len(self.node_capacities)
396
+ if node_count > 100:
397
+ # In large networks, cap layers per node
398
+ max_layers_for_node = min(max_layers_for_node, self.MAX_LAYERS_PER_NODE)
399
+ logger.debug(f"Large network ({node_count} nodes): capped to {max_layers_for_node} layers")
400
+
401
+ if max_layers_for_node < 1:
402
+ logger.warning(f"Node {node_id[:8]}... has insufficient memory for even 1 layer")
403
+ return []
404
+
405
+ # Find layers that need more replicas
406
+ assigned_layers = []
407
+
408
+ # SCALABILITY STRATEGY:
409
+ # High-capacity nodes (>8GB) are prioritized for Layer 0 (Driver) and Last Layer (Validator)
410
+ # This creates parallel pipelines ("Training Gangs")
411
+ is_high_capacity = available_memory_mb > 8000
412
+ is_medium_capacity = available_memory_mb > VALIDATOR_MIN_MEMORY_MB
413
+
414
+ # Count current validators for DYNAMIC stake requirement
415
+ num_drivers = len(self.layer_assignments[0])
416
+ num_validators = len(self.layer_assignments[max(0, self.current_num_layers - 1)])
417
+
418
+ # VALIDATOR ELIGIBILITY: Dynamic stake based on network size!
419
+ # - Few validators (1-10): 100 NEURO (accessible for bootstrap)
420
+ # - Many validators (1000+): 2500 NEURO (security at scale)
421
+ required_validator_stake = get_dynamic_validator_stake(num_validators)
422
+ has_validator_stake = staked_amount >= required_validator_stake
423
+
424
+ # DISTRIBUTED TRAINING ASSIGNMENT:
425
+ # Based on role_hint from capacity calculation, assign appropriate layers
426
+
427
+ last_layer = max(0, self.current_num_layers - 1) if self.current_num_layers > 0 else 0
428
+
429
+ # Determine role based on network needs and node capacity
430
+ if role_hint == "DRIVER":
431
+ # Become Driver: get Layer 0 + as many layers as we can
432
+ should_be_driver = True
433
+ should_be_validator = False # Driver doesn't need to also be Validator
434
+ logger.info(f"Node {node_id[:8]}... assigned as DRIVER (first node, will embed data)")
435
+
436
+ elif role_hint == "WORKER+VALIDATOR":
437
+ # Become Worker+Validator: skip Layer 0, get middle + last layer
438
+ # This is the KEY for distributed training - enables pipeline with loss computation
439
+ should_be_driver = False
440
+ should_be_validator = True # Need LM head to compute loss!
441
+ # For bootstrap, allow Validator without stake requirement
442
+ has_validator_stake = True # Bootstrap mode - stake checked later
443
+ logger.info(f"Node {node_id[:8]}... assigned as WORKER+VALIDATOR (will compute loss, enable distributed training)")
444
+
445
+ else: # "WORKER"
446
+ # Become Worker: middle layers only (redundancy)
447
+ should_be_driver = False
448
+ should_be_validator = False
449
+ logger.info(f"Node {node_id[:8]}... assigned as WORKER (middle layers, redundancy)")
450
+
451
+ # Assign Layer 0 (Driver) - if we should be Driver
452
+ if should_be_driver and len(assigned_layers) < max_layers_for_node:
453
+ if not any(a.node_id == node_id for a in self.layer_assignments[0]):
454
+ self._assign_layer(0, node_id, node_url, grpc_addr)
455
+ assigned_layers.append(0)
456
+ # Ensure current_num_layers accounts for layer 0
457
+ if self.current_num_layers == 0:
458
+ self.current_num_layers = 1
459
+
460
+ # 2. DISTRIBUTED LAYER ASSIGNMENT
461
+ # Based on role, fill appropriate layers for pipeline parallelism
462
+ is_driver = 0 in assigned_layers
463
+ last_layer = max(0, self.current_num_layers - 1)
464
+
465
+ if role_hint == "WORKER+VALIDATOR":
466
+ # CRITICAL: Assign LAST layer FIRST (for LM head) to ensure we can compute loss
467
+ if len(assigned_layers) < max_layers_for_node:
468
+ if last_layer not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
469
+ self._assign_layer(last_layer, node_id, node_url, grpc_addr)
470
+ assigned_layers.append(last_layer)
471
+ logger.info(f"Node {node_id[:8]}... assigned last layer {last_layer} (Validator role)")
472
+
473
+ # Then fill middle layers from the END (closer to Validator)
474
+ # This creates contiguous layer ranges: Driver has 0-N, Validator has M-31
475
+ for layer_id in range(last_layer - 1, 0, -1): # Reverse order, skip layer 0
476
+ if len(assigned_layers) >= max_layers_for_node:
477
+ break
478
+ if layer_id not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[layer_id]):
479
+ self._assign_layer(layer_id, node_id, node_url, grpc_addr)
480
+ assigned_layers.append(layer_id)
481
+ else:
482
+ # Driver or Worker: fill from layer 1 onwards
483
+ layers_to_check = range(1, self.current_num_layers)
484
+
485
+ for layer_id in layers_to_check:
486
+ if len(assigned_layers) >= max_layers_for_node:
487
+ break
488
+
489
+ current_replicas = len(self.layer_assignments[layer_id])
490
+ if current_replicas < self.MIN_REPLICAS:
491
+ if layer_id not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[layer_id]):
492
+ self._assign_layer(layer_id, node_id, node_url, grpc_addr)
493
+ assigned_layers.append(layer_id)
494
+
495
+ # 3. Assign Last Layer (Validator) if Driver with extra capacity
496
+ if should_be_validator and len(assigned_layers) < max_layers_for_node:
497
+ if last_layer not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
498
+ self._assign_layer(last_layer, node_id, node_url, grpc_addr)
499
+ assigned_layers.append(last_layer)
500
+
501
+ # 4. Calculate remaining capacity
502
+ remaining_capacity = max_layers_for_node - len(assigned_layers)
503
+
504
+ # 5. DHT layer discovery already done above (at role assignment)
505
+ # dht_layers variable is already populated
506
+
507
+ # 6. If we still have capacity, grow the model
508
+ if remaining_capacity > 0:
509
+ # CAP MODEL GROWTH for solo/early network as safety net
510
+ # Even with gradient checkpointing, too many layers cause OOM due to:
511
+ # - Optimizer states (2x model size for Adam)
512
+ # - Gradient accumulation during backward pass
513
+ # - PyTorch memory fragmentation
514
+ # 32 layers is safe for most devices (~200M params with 512 hidden)
515
+ MAX_SOLO_LAYERS = 32 # Conservative cap for solo node training
516
+ total_nodes = len(set(
517
+ a.node_id for assignments in self.layer_assignments.values()
518
+ for a in assignments
519
+ ))
520
+
521
+ if total_nodes <= 2: # Solo or near-solo mode
522
+ max_growth = max(0, MAX_SOLO_LAYERS - len(assigned_layers))
523
+ if remaining_capacity > max_growth:
524
+ logger.warning(f"[SOLO MODE] Capping growth from {remaining_capacity} to {max_growth} layers "
525
+ f"(MAX_SOLO_LAYERS={MAX_SOLO_LAYERS} prevents OOM)")
526
+ remaining_capacity = max_growth
527
+
528
+ # Add new layers to grow the model
529
+ new_layers = self._grow_model(remaining_capacity, node_id, node_url, grpc_addr)
530
+ assigned_layers.extend(new_layers)
531
+
532
+ # Handle embedding and LM head tracking
533
+ # Any node with Layer 0 has embedding
534
+ if 0 in assigned_layers:
535
+ # Update tracking (just keeps one for reference, but multiple exist)
536
+ self.embedding_holder = node_id
537
+ logger.info(f"Node {node_id[:8]}... became a Driver (Layer 0)")
538
+
539
+ # Any node with highest assigned layer becomes Validator (LM head holder)
540
+ # CRITICAL FIX: Use max(assigned_layers), NOT current_num_layers - 1
541
+ # This handles the case where checkpoint has fewer layers than stale DHT data
542
+ if assigned_layers:
543
+ actual_last_layer = max(assigned_layers)
544
+ # Check if this node holds the highest layer in the network
545
+ # OR if there's no other holder for this layer yet
546
+ if not self.lm_head_holder or actual_last_layer >= (self.current_num_layers - 1):
547
+ self.lm_head_holder = node_id
548
+ self.current_num_layers = actual_last_layer + 1 # Update to match reality
549
+ logger.info(f"Node {node_id[:8]}... became a Validator (Layer {actual_last_layer})")
550
+
551
+ # EARLY NETWORK NOTICE: When there are <10 nodes, each must hold many/all layers
552
+ # This is TEMPORARY - as more nodes join, layers will be distributed
553
+ if len(assigned_layers) > 50:
554
+ logger.warning(f"Node {node_id[:8]}... holding {len(assigned_layers)} layers due to low network size")
555
+ logger.warning(f"This is temporary - as more nodes join, model will shard across network")
556
+
557
+ logger.info(f"Node {node_id[:8]}... registered: {len(assigned_layers)} layers assigned "
558
+ f"(capacity: {max_layers_for_node} layers, {available_memory_mb:.0f}MB)")
559
+
560
+ return assigned_layers
561
+
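register_node above interleaves role selection with DHT discovery and layer assignment. The documented priority can be restated on its own; a simplified, self-contained sketch (function and parameter names here are illustrative, not package APIs):

```python
def choose_role(has_full_node: bool, is_the_full_node: bool,
                driver_count: int, validator_count: int,
                num_layers: int, min_replicas: int = 2) -> str:
    """Mirror of the role-assignment priority documented in register_node."""
    if has_full_node:
        # Never create a second Driver next to an existing full node:
        # new nodes extend the pipeline from the tail instead.
        return "DRIVER" if is_the_full_node else "WORKER+VALIDATOR"
    if driver_count < min_replicas:
        return "DRIVER"
    if validator_count < min_replicas and num_layers > 1:
        return "WORKER+VALIDATOR"
    return "WORKER"

# Second node joining a network that already has one full node:
print(choose_role(True, False, driver_count=1, validator_count=1, num_layers=32))
# WORKER+VALIDATOR
```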
562
+ def _assign_layer(self, layer_id: int, node_id: str, node_url: str, grpc_addr: str):
563
+ """Assign a layer to a node."""
564
+ assignment = LayerAssignment(
565
+ layer_id=layer_id,
566
+ node_id=node_id,
567
+ node_url=node_url,
568
+ grpc_addr=grpc_addr,
569
+ )
570
+ self.layer_assignments[layer_id].append(assignment)
571
+
572
+ # Announce to DHT
573
+ if self.dht:
574
+ try:
575
+ import json
576
+ key = f"layer_{layer_id}"
577
+ current = self.dht.lookup_value(key)
578
+ holders = json.loads(current) if current else []
579
+ if grpc_addr not in holders:
580
+ holders.append(grpc_addr)
581
+ self.dht.store(key, json.dumps(holders))
582
+ except Exception as e:
583
+ logger.debug(f"DHT announce failed: {e}")
584
+
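_assign_layer publishes layer holders to the DHT under the key layer_&lt;id&gt; as a JSON list of gRPC addresses. A dict-backed sketch of that record shape (DictDHT is a stand-in for illustration, not the package's DHT protocol):

```python
import json

class DictDHT:
    """Stand-in DHT: just enough to show the layer_<id> record shape."""
    def __init__(self):
        self._store = {}
    def lookup_value(self, key):
        return self._store.get(key)
    def store(self, key, value):
        self._store[key] = value

def announce_layer(dht, layer_id: int, grpc_addr: str) -> None:
    """Append this node's gRPC address to the layer's holder list."""
    key = f"layer_{layer_id}"
    current = dht.lookup_value(key)
    holders = json.loads(current) if current else []
    if grpc_addr not in holders:
        holders.append(grpc_addr)
    dht.store(key, json.dumps(holders))

dht = DictDHT()
announce_layer(dht, 0, "10.0.0.5:50051")
announce_layer(dht, 0, "10.0.0.7:50051")
print(dht.lookup_value("layer_0"))  # ["10.0.0.5:50051", "10.0.0.7:50051"]
```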
585
+ def _discover_network_layers_from_dht(self) -> Set[int]:
586
+ """
587
+ Query DHT to discover which layers exist in the network.
588
+
589
+ DECENTRALIZED COORDINATION:
590
+ - Each node announces "layer_X" to DHT when it holds layer X
591
+ - New nodes query DHT to see what layers already exist
592
+ - This prevents layer overlap without centralized coordination
593
+ """
594
+ discovered_layers = set()
595
+
596
+ if not self.dht:
597
+ return discovered_layers
598
+
599
+ try:
600
+ # Query for layers 0-1000 (reasonable max)
601
+ # DHT lookup is fast - O(log N) hops
602
+ for layer_id in range(min(1000, self.current_num_layers + 100)):
603
+ key = f"layer_{layer_id}"
604
+ try:
605
+ value = self.dht.lookup_value(key)
606
+ if value:
607
+ discovered_layers.add(layer_id)
608
+ except Exception:
609
+ continue
610
+
611
+ if discovered_layers:
612
+ logger.info(f"DHT layer discovery: found {len(discovered_layers)} layers "
613
+ f"(range: {min(discovered_layers)}-{max(discovered_layers)})")
614
+ except Exception as e:
615
+ logger.debug(f"DHT layer discovery failed: {e}")
616
+
617
+ return discovered_layers
618
+
619
+ def _grow_model(
620
+ self,
621
+ num_new_layers: int,
622
+ node_id: str,
623
+ node_url: str,
624
+ grpc_addr: str
625
+ ) -> List[int]:
626
+ """
627
+ Grow the model by adding new layers.
628
+
629
+ This is how the model organically grows with the network!
630
+ """
631
+ new_layers = []
632
+
633
+ for _ in range(num_new_layers):
634
+ new_layer_id = self.current_num_layers
635
+ self._assign_layer(new_layer_id, node_id, node_url, grpc_addr)
636
+ new_layers.append(new_layer_id)
637
+ self.current_num_layers += 1
638
+
639
+ if new_layers:
640
+ logger.info(f"Model grew: now {self.current_num_layers} layers "
641
+ f"(added layers {new_layers[0]}-{new_layers[-1]})")
642
+
643
+ return new_layers
644
+
645
+ def upgrade_to_validator(self, node_id: str, node_url: str, grpc_addr: str) -> bool:
646
+ """
647
+ Upgrade a node to Validator role (assign LM head) when stake requirement is met.
648
+
649
+ This is called when a node stakes enough NEURO to become a Validator.
650
+ No restart required - the node's role is upgraded dynamically.
651
+
652
+ Returns True if upgrade was successful.
653
+ """
654
+ from neuroshard.core.economics.constants import VALIDATOR_MIN_MEMORY_MB
655
+
656
+ with self.lock:
657
+ # Check if node has sufficient memory
658
+ memory = self.node_capacities.get(node_id, 0)
659
+ if memory < VALIDATOR_MIN_MEMORY_MB:
660
+ logger.warning(f"Node {node_id[:8]}... cannot be Validator: insufficient memory ({memory}MB)")
661
+ return False
662
+
663
+ # Check if already a validator
664
+ last_layer = max(0, self.current_num_layers - 1)
665
+ if any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
666
+ logger.info(f"Node {node_id[:8]}... is already a Validator")
667
+ return True
668
+
669
+ # Assign the last layer (LM head)
670
+ self._assign_layer(last_layer, node_id, node_url, grpc_addr)
671
+ self.lm_head_holder = node_id
672
+
673
+ logger.info(f"Node {node_id[:8]}... upgraded to VALIDATOR (assigned layer {last_layer})")
674
+ return True
675
+
676
+ def demote_from_validator(self, node_id: str) -> bool:
677
+ """
678
+ Demote a node from Validator role when stake drops below requirement.
679
+
680
+ This is called when:
681
+ 1. A validator unstakes and drops below the required amount
682
+ 2. The network grows and the required stake increases (tier change)
683
+
684
+ The node keeps its other layer assignments but loses the LM head.
685
+
686
+ Returns True if demotion was successful.
687
+ """
688
+ with self.lock:
689
+ return self._demote_from_validator_unlocked(node_id)
690
+
691
+ def _demote_from_validator_unlocked(self, node_id: str) -> bool:
692
+ """
693
+ Internal demotion logic (caller must hold self.lock).
694
+
695
+ Split out to avoid deadlock when called from validate_all_validators().
696
+ """
697
+ last_layer = max(0, self.current_num_layers - 1)
698
+
699
+ # Check if node is currently a validator
700
+ current_assignments = self.layer_assignments.get(last_layer, [])
701
+ was_validator = any(a.node_id == node_id for a in current_assignments)
702
+
703
+ if not was_validator:
704
+ logger.debug(f"Node {node_id[:8]}... is not a validator, nothing to demote")
705
+ return False
706
+
707
+ # Remove from last layer assignments
708
+ self.layer_assignments[last_layer] = [
709
+ a for a in current_assignments if a.node_id != node_id
710
+ ]
711
+
712
+ # Update lm_head_holder if this was the holder
713
+ if self.lm_head_holder == node_id:
714
+ # Find another validator if available
715
+ remaining = self.layer_assignments.get(last_layer, [])
716
+ if remaining:
717
+ self.lm_head_holder = remaining[0].node_id
718
+ else:
719
+ self.lm_head_holder = None
720
+
721
+ logger.warning(f"Node {node_id[:8]}... DEMOTED from Validator (insufficient stake)")
722
+ return True
723
+
724
+ def validate_all_validators(self, get_stake_fn) -> List[str]:
725
+ """
726
+ Validate all current validators still meet stake requirements.
727
+
728
+ Called periodically or when stake tier changes to ensure all validators
729
+ have sufficient stake for the current network size.
730
+
731
+ IMPORTANT: Never demotes below MIN_VALIDATORS (2) to ensure the network
732
+ can always compute real loss. The stake requirement only applies to
733
+ validators beyond the minimum.
734
+
735
+ Args:
736
+ get_stake_fn: Function(node_id) -> float that returns current stake
737
+
738
+ Returns:
739
+ List of node_ids that were demoted
740
+ """
741
+ from neuroshard.core.economics.constants import get_dynamic_validator_stake
742
+
743
+ MIN_VALIDATORS = 2 # Network needs at least 2 validators to function
744
+
745
+ demoted = []
746
+
747
+ with self.lock:
748
+ last_layer = max(0, self.current_num_layers - 1)
749
+ current_validators = list(self.layer_assignments.get(last_layer, []))
750
+ num_validators = len(current_validators)
751
+
752
+ # CRITICAL: Never demote below MIN_VALIDATORS
753
+ # Otherwise the network can't compute real cross-entropy loss!
754
+ if num_validators <= MIN_VALIDATORS:
755
+ logger.debug(f"Only {num_validators} validators - skipping stake check (minimum {MIN_VALIDATORS} required)")
756
+ return []
757
+
758
+ # Get current stake requirement
759
+ required_stake = get_dynamic_validator_stake(num_validators)
760
+
761
+ # Sort validators by stake (lowest first) to demote lowest-stake first
762
+ validators_with_stake = [
763
+ (assignment, get_stake_fn(assignment.node_id))
764
+ for assignment in current_validators
765
+ ]
766
+ validators_with_stake.sort(key=lambda x: x[1]) # Lowest stake first
767
+
768
+ for assignment, node_stake in validators_with_stake:
769
+ # Check if we'd go below minimum
770
+ remaining_validators = num_validators - len(demoted)
771
+ if remaining_validators <= MIN_VALIDATORS:
772
+ logger.info(f"Stopping demotion: {remaining_validators} validators remain (minimum {MIN_VALIDATORS})")
773
+ break
774
+
775
+ if node_stake < required_stake:
776
+ logger.warning(
777
+ f"Validator {assignment.node_id[:8]}... has {node_stake:.0f} NEURO "
778
+ f"but {required_stake:.0f} required - DEMOTING"
779
+ )
780
+ # Use unlocked version since we already hold self.lock
781
+ if self._demote_from_validator_unlocked(assignment.node_id):
782
+ demoted.append(assignment.node_id)
783
+
784
+ return demoted
785
+
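The demotion ordering in validate_all_validators (lowest stake first, never dropping below two validators) can be restated compactly. The stake values and the 500 NEURO requirement below are placeholders; the real requirement comes from get_dynamic_validator_stake:

```python
def plan_demotions(stakes: dict, required_stake: float, min_validators: int = 2) -> list:
    """Return node_ids to demote, lowest stake first, never below min_validators."""
    demoted = []
    remaining = len(stakes)
    for node_id, stake in sorted(stakes.items(), key=lambda kv: kv[1]):
        if remaining <= min_validators:
            break
        if stake < required_stake:
            demoted.append(node_id)
            remaining -= 1
    return demoted

# Four validators, 500 NEURO required: the two under-staked ones go, two stay.
print(plan_demotions({"a": 50, "b": 200, "c": 800, "d": 1200}, required_stake=500))
# ['a', 'b']
```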
786
+ def unregister_node(self, node_id: str):
787
+ """
788
+ Unregister a node and redistribute its layers.
789
+
790
+ This handles graceful degradation when nodes leave.
791
+ """
792
+ with self.lock:
793
+ # Remove from capacities
794
+ self.node_capacities.pop(node_id, None)
795
+
796
+ # Find all layers this node was holding
797
+ orphaned_layers = []
798
+
799
+ for layer_id, assignments in self.layer_assignments.items():
800
+ # Remove this node's assignment
801
+ self.layer_assignments[layer_id] = [
802
+ a for a in assignments if a.node_id != node_id
803
+ ]
804
+
805
+ # Check if layer is now orphaned (< MIN_REPLICAS)
806
+ if len(self.layer_assignments[layer_id]) < self.MIN_REPLICAS:
807
+ orphaned_layers.append(layer_id)
808
+
809
+ # Handle embedding/head holder leaving
810
+ if self.embedding_holder == node_id:
811
+ self.embedding_holder = None
812
+ if self.lm_head_holder == node_id:
813
+ self.lm_head_holder = None
814
+
815
+ if orphaned_layers:
816
+ logger.warning(f"Node {node_id[:8]}... left, {len(orphaned_layers)} layers need redistribution")
817
+ # In production, we would trigger redistribution here
818
+
819
+ def get_layer_holders(self, layer_id: int) -> List[LayerAssignment]:
820
+ """Get all nodes holding a specific layer."""
821
+ with self.lock:
822
+ return list(self.layer_assignments.get(layer_id, []))
823
+
824
+ def get_pipeline_route(self) -> List[Tuple[int, str]]:
825
+ """
826
+ Get the route for pipeline inference.
827
+
828
+ Returns list of (layer_id, grpc_addr) for each layer in order.
829
+
830
+ Filters out dead/stale nodes based on heartbeat timeout.
831
+ """
832
+ with self.lock:
833
+ route = []
834
+ now = time.time()
835
+
836
+ for layer_id in range(self.current_num_layers):
837
+ holders = self.layer_assignments.get(layer_id, [])
838
+ if not holders:
839
+ logger.error(f"Layer {layer_id} has no holders!")
840
+ continue
841
+
842
+ # ROBUSTNESS: Filter out stale holders (expired heartbeat)
843
+ active_holders = [
844
+ h for h in holders
845
+ if (now - h.last_heartbeat) < self.HEARTBEAT_TIMEOUT
846
+ ]
847
+
848
+ if not active_holders:
849
+ logger.warning(f"Layer {layer_id} has no ACTIVE holders "
850
+ f"({len(holders)} total, all stale)")
851
+ continue
852
+
853
+ # Pick best active holder (most recent heartbeat)
854
+ active_holders.sort(key=lambda h: -h.last_heartbeat)
855
+ route.append((layer_id, active_holders[0].grpc_addr))
856
+
857
+ logger.debug(f"Layer {layer_id}: selected {active_holders[0].node_id[:16]}... "
858
+ f"(heartbeat {now - active_holders[0].last_heartbeat:.1f}s ago)")
859
+
860
+ return route
861
+
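The holder-selection rule in get_pipeline_route (drop stale heartbeats, then prefer the freshest) in isolation; a minimal sketch using plain tuples rather than LayerAssignment objects:

```python
import time

HEARTBEAT_TIMEOUT = 120  # seconds, same constant as the pool

def pick_holder(holders, now=None, timeout=HEARTBEAT_TIMEOUT):
    """holders: list of (grpc_addr, last_heartbeat). Return freshest active addr or None."""
    now = time.time() if now is None else now
    active = [(addr, hb) for addr, hb in holders if (now - hb) < timeout]
    if not active:
        return None
    active.sort(key=lambda h: -h[1])  # most recent heartbeat first
    return active[0][0]

now = 1_000_000.0
holders = [("node-a:50051", now - 300),   # stale, filtered out
           ("node-b:50051", now - 10),
           ("node-c:50051", now - 45)]
print(pick_holder(holders, now=now))      # node-b:50051
```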
862
+ def get_network_capacity(self) -> NetworkCapacity:
863
+ """Get current network capacity with dynamic architecture."""
864
+ with self.lock:
865
+ total_memory = sum(self.node_capacities.values())
866
+
867
+ # Calculate max layers based on current architecture
868
+ if self.current_architecture:
869
+ memory_per_layer = estimate_memory_per_layer(self.current_architecture)
870
+ max_layers = int(total_memory * 0.6 / memory_per_layer)
871
+ else:
872
+ max_layers = 0
873
+
874
+ layer_coverage = {
875
+ layer_id: len(assignments)
876
+ for layer_id, assignments in self.layer_assignments.items()
877
+ }
878
+
879
+ return NetworkCapacity(
880
+ total_nodes=len(self.node_capacities),
881
+ total_memory_mb=total_memory,
882
+ max_layers=max_layers,
883
+ assigned_layers=self.current_num_layers,
884
+ layer_coverage=layer_coverage,
885
+ )
886
+
887
+ def heartbeat(self, node_id: str, layer_ids: List[int]):
888
+ """Update heartbeat for a node's layers."""
889
+ with self.lock:
890
+ now = time.time()
891
+ for layer_id in layer_ids:
892
+ for assignment in self.layer_assignments.get(layer_id, []):
893
+ if assignment.node_id == node_id:
894
+ assignment.last_heartbeat = now
895
+
896
+ def cleanup_stale_assignments(self) -> int:
897
+ """
898
+ Remove stale layer assignments (nodes that haven't heartbeat recently).
899
+
900
+ Returns number of stale assignments removed.
901
+
902
+ Called periodically to prevent dead peers from being selected for pipeline routing.
903
+ """
904
+ with self.lock:
905
+ now = time.time()
906
+ removed_count = 0
907
+
908
+ for layer_id, assignments in list(self.layer_assignments.items()):
909
+ # Filter out stale assignments
910
+ active_assignments = [
911
+ a for a in assignments
912
+ if (now - a.last_heartbeat) < self.HEARTBEAT_TIMEOUT
913
+ ]
914
+
915
+ stale_count = len(assignments) - len(active_assignments)
916
+ if stale_count > 0:
917
+ logger.info(f"Layer {layer_id}: removed {stale_count} stale assignments "
918
+ f"({len(active_assignments)} remain)")
919
+ removed_count += stale_count
920
+
921
+ # Update assignments
922
+ if active_assignments:
923
+ self.layer_assignments[layer_id] = active_assignments
924
+ else:
925
+ # No active holders for this layer!
926
+ logger.warning(f"Layer {layer_id}: NO active holders remaining!")
927
+ del self.layer_assignments[layer_id]
928
+
929
+ return removed_count
930
+
931
+
+ class DynamicNeuroLLM:
+ """
+ A NeuroLLM that dynamically scales with the network.
+
+ Key differences from fixed-phase model:
+ - Number of layers AND hidden dimension determined by network capacity
+ - Layers are distributed across nodes
+ - Model grows organically in BOTH width and depth
+ - Architecture adapts automatically as network expands
+ """
+
+ def __init__(
944
+ self,
945
+ node_id: str,
946
+ layer_pool: DynamicLayerPool,
947
+ device: str = "cpu"
948
+ ):
949
+ self.node_id = node_id
950
+ self.layer_pool = layer_pool
951
+ self.device = device
952
+
953
+ # Get current architecture from layer pool
954
+ if layer_pool.current_architecture is None:
955
+ raise RuntimeError("Layer pool has no architecture - call _auto_recalculate_architecture first")
956
+ self.architecture = layer_pool.current_architecture
957
+
958
+ # My assigned layers
959
+ self.my_layers: Dict[int, torch.nn.Module] = {}
960
+ self.my_layer_ids: List[int] = []
961
+
962
+ # Callback for when layers change (set by DynamicNeuroNode to sync state)
963
+ self._on_layers_changed: Optional[callable] = None
964
+
965
+ # Reference to P2P manager for DHT updates during layer removal
966
+ self._p2p_manager = None
967
+
968
+ # Do I hold embedding/head?
969
+ self.has_embedding = False
970
+ self.has_lm_head = False
971
+
972
+ # Shared components (if I hold them)
973
+ self.embedding: Optional[torch.nn.Embedding] = None
974
+ self.lm_head: Optional[torch.nn.Linear] = None
975
+ self.final_norm: Optional[torch.nn.Module] = None
976
+
977
+ # Training mode flag (PyTorch convention)
978
+ self.training = False
979
+
980
+ logger.info(f"DynamicNeuroLLM initialized for node {node_id[:8]}... "
981
+ f"with {self.architecture.num_layers}L × {self.architecture.hidden_dim}H architecture")
982
+
983
+ def initialize_layers(self, layer_ids: List[int]):
984
+ """Initialize the layers assigned to this node using DYNAMIC architecture."""
985
+ from neuroshard.core.model.llm import NeuroLLMConfig, NeuroDecoderLayer
986
+
987
+ # Create config from current architecture (DYNAMIC!)
988
+ config = NeuroLLMConfig(
989
+ hidden_dim=self.architecture.hidden_dim,
990
+ intermediate_dim=self.architecture.intermediate_dim,
991
+ num_layers=self.architecture.num_layers,
992
+ num_heads=self.architecture.num_heads,
993
+ num_kv_heads=self.architecture.num_kv_heads,
994
+ vocab_size=self.architecture.vocab_size,
995
+ max_seq_len=self.architecture.max_seq_len,
996
+ dropout=self.architecture.dropout,
997
+ rope_theta=self.architecture.rope_theta,
998
+ )
999
+
1000
+ for layer_id in layer_ids:
1001
+ layer = NeuroDecoderLayer(config, layer_id)
1002
+ layer.to(self.device)
1003
+ self.my_layers[layer_id] = layer
1004
+
1005
+ self.my_layer_ids = sorted(layer_ids)
1006
+
1007
+ # Initialize embedding if I'm the holder (uses dynamic hidden_dim!)
1008
+ # Vocab capacity can grow dynamically as tokenizer learns more merges
1009
+ self.vocab_capacity = INITIAL_VOCAB_SIZE
1010
+ if self.layer_pool.embedding_holder == self.node_id:
1011
+ self.embedding = torch.nn.Embedding(self.vocab_capacity, self.architecture.hidden_dim)
1012
+ self.embedding.to(self.device)
1013
+ self.has_embedding = True
1014
+
1015
+ # Initialize LM head if I'm the holder (uses dynamic hidden_dim!)
1016
+ if self.layer_pool.lm_head_holder == self.node_id:
1017
+ self.lm_head = torch.nn.Linear(self.architecture.hidden_dim, self.vocab_capacity, bias=False)
1018
+ from neuroshard.core.model.llm import RMSNorm
1019
+ self.final_norm = RMSNorm(self.architecture.hidden_dim)
1020
+ self.lm_head.to(self.device)
1021
+ self.final_norm.to(self.device)
1022
+ self.has_lm_head = True
1023
+
1024
+ logger.info(f"Initialized {len(layer_ids)} layers: {layer_ids}, "
1025
+ f"arch={self.architecture.num_layers}L×{self.architecture.hidden_dim}H, "
1026
+ f"embedding={self.has_embedding}, head={self.has_lm_head}")
1027
+
1028
+ def initialize_lm_head(self) -> bool:
1029
+ """
1030
+ Dynamically initialize the LM head (for validator upgrade).
1031
+
1032
+ Called when a node is upgraded to Validator after staking.
1033
+ No restart required - initializes the head in place.
1034
+
1035
+ Returns True if initialization was successful.
1036
+ """
1037
+ if self.has_lm_head:
1038
+ logger.info("LM head already initialized")
1039
+ return True
1040
+
1041
+ try:
1042
+ from neuroshard.core.model.llm import RMSNorm
1043
+
1044
+ self.lm_head = torch.nn.Linear(self.architecture.hidden_dim, self.vocab_capacity, bias=False)
1045
+ self.final_norm = RMSNorm(self.architecture.hidden_dim)
1046
+ self.lm_head.to(self.device)
1047
+ self.final_norm.to(self.device)
1048
+ self.has_lm_head = True
1049
+
1050
+ # Add last layer to my layers if not already there
1051
+ last_layer = self.architecture.num_layers - 1
1052
+ if last_layer not in self.my_layer_ids:
1053
+ self.my_layer_ids.append(last_layer)
1054
+ self.my_layer_ids = sorted(self.my_layer_ids)
1055
+
1056
+ logger.info(f"LM head initialized! Now computing REAL cross-entropy loss")
1057
+ return True
1058
+ except Exception as e:
+     logger.error(f"Failed to initialize LM head: {e}")
+     return False
1060
+
1061
+ def disable_lm_head(self) -> bool:
1062
+ """
1063
+ Disable the LM head (for validator demotion).
1064
+
1065
+ Called when a validator is demoted due to insufficient stake.
1066
+ The node reverts to Worker role and uses activation norm as loss.
1067
+
1068
+ Returns True if demotion was successful.
1069
+ """
1070
+ if not self.has_lm_head:
1071
+ logger.debug("LM head not initialized, nothing to disable")
1072
+ return False
1073
+
1074
+ try:
1075
+ # Free memory from LM head
1076
+ if self.lm_head is not None:
1077
+ del self.lm_head
1078
+ self.lm_head = None
1079
+ if self.final_norm is not None:
1080
+ del self.final_norm
1081
+ self.final_norm = None
1082
+
1083
+ self.has_lm_head = False
1084
+
1085
+ # Force garbage collection to free memory
1086
+ import gc
1087
+ gc.collect()
1088
+ if torch.cuda.is_available():
1089
+ torch.cuda.empty_cache()
1090
+
1091
+ logger.warning(f"LM head DISABLED - node demoted to Worker (will use activation norm as loss)")
1092
+ return True
1093
+ except Exception as e:
1094
+ logger.error(f"Failed to disable LM head: {e}")
1095
+ return False
1096
+
1097
+ def add_layers(self, new_layer_ids: List[int]) -> List[int]:
1098
+ """
1099
+ Dynamically add new layers to a running model.
1100
+
1101
+ This allows the model to grow during training without restart!
1102
+ Called when:
1103
+ - Network needs more layers
1104
+ - Node has available memory
1105
+ - Vocab expansion freed up memory
1106
+
1107
+ Args:
1108
+ new_layer_ids: Layer IDs to add
1109
+
1110
+ Returns:
1111
+ List of layer IDs that were successfully added
1112
+ """
1113
+ from neuroshard.core.model.llm import NeuroLLMConfig, NeuroDecoderLayer
1114
+
1115
+ # Filter out layers we already have
1116
+ layers_to_add = [lid for lid in new_layer_ids if lid not in self.my_layers]
1117
+ if not layers_to_add:
1118
+ return []
1119
+
1120
+ # Check memory before adding
1121
+ memory_per_layer = estimate_memory_per_layer(self.architecture)
1122
+ required_mb = memory_per_layer * len(layers_to_add)
1123
+
1124
+ try:
1125
+ if torch.cuda.is_available() and self.device != "cpu":
1126
+ free_mb = (torch.cuda.get_device_properties(0).total_memory -
1127
+ torch.cuda.memory_allocated()) / (1024 * 1024)
1128
+ else:
1129
+ import psutil
1130
+ free_mb = psutil.virtual_memory().available / (1024 * 1024)
1131
+
1132
+ if free_mb < required_mb * 1.5: # 1.5x safety margin
1133
+ logger.warning(f"[LAYER] Insufficient memory to add {len(layers_to_add)} layers: "
1134
+ f"need {required_mb:.0f}MB, have {free_mb:.0f}MB")
1135
+ # Add as many as we can fit
1136
+ can_fit = int(free_mb / (memory_per_layer * 1.5))
1137
+ layers_to_add = layers_to_add[:can_fit]
1138
+ if not layers_to_add:
1139
+ return []
1140
+ except Exception as e:
1141
+ logger.warning(f"[LAYER] Memory check failed: {e}, proceeding cautiously")
1142
+
1143
+ # Create config from architecture
1144
+ config = NeuroLLMConfig(
1145
+ hidden_dim=self.architecture.hidden_dim,
1146
+ intermediate_dim=self.architecture.intermediate_dim,
1147
+ num_layers=self.architecture.num_layers,
1148
+ num_heads=self.architecture.num_heads,
1149
+ num_kv_heads=self.architecture.num_kv_heads,
1150
+ vocab_size=self.architecture.vocab_size,
1151
+ max_seq_len=self.architecture.max_seq_len,
1152
+ dropout=self.architecture.dropout,
1153
+ rope_theta=self.architecture.rope_theta,
1154
+ )
1155
+
1156
+ added = []
1157
+ for layer_id in layers_to_add:
1158
+ try:
1159
+ layer = NeuroDecoderLayer(config, layer_id)
1160
+ layer.to(self.device)
1161
+ self.my_layers[layer_id] = layer
1162
+ added.append(layer_id)
1163
+ except Exception as e:
1164
+ logger.error(f"[LAYER] Failed to add layer {layer_id}: {e}")
1165
+ break # Stop on first failure
1166
+
1167
+ if added:
1168
+ self.my_layer_ids = sorted(self.my_layers.keys())
1169
+ logger.info(f"[LAYER] ✅ Added {len(added)} layers: {added}")
1170
+ logger.info(f"[LAYER] Now holding {len(self.my_layer_ids)} layers: {self.my_layer_ids[:5]}...{self.my_layer_ids[-5:]}")
1171
+
1172
+ return added
1173
+
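The memory guard in add_layers boils down to how many extra layers fit in the currently free memory under the 1.5x safety margin. A standalone restatement of just that check (the numbers below are illustrative):

```python
def layers_that_fit(free_mb: float, memory_per_layer_mb: float, requested: int,
                    margin: float = 1.5) -> int:
    """How many of `requested` layers can be added without exceeding free memory."""
    if free_mb >= memory_per_layer_mb * requested * margin:
        return requested
    return min(requested, int(free_mb / (memory_per_layer_mb * margin)))

# 4 layers requested at ~180 MB each with 1000 MB free: only 3 fit under the margin.
print(layers_that_fit(1000, 180, 4))  # 3
```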
1174
+ def remove_layers(self, layer_ids_to_remove: List[int], layer_pool=None, p2p_manager=None) -> List[int]:
1175
+ """
1176
+ Dynamically remove layers from a running model.
1177
+
1178
+ This allows the model to shrink during training without restart!
1179
+ Called when:
1180
+ - Vocab growth needs more memory
1181
+ - Network is redistributing layers
1182
+ - Memory pressure detected
1183
+
1184
+ IMPORTANT: This also updates layer_pool and DHT announcements so other
1185
+ nodes know we no longer hold these layers!
1186
+
1187
+ Args:
1188
+ layer_ids_to_remove: Layer IDs to remove
1189
+ layer_pool: DynamicLayerPool to update assignments (optional)
1190
+ p2p_manager: P2PManager to update DHT announcements (optional)
1191
+
1192
+ Returns:
1193
+ List of layer IDs that were successfully removed
1194
+ """
1195
+ # Can't remove layers we don't have
1196
+ layers_to_remove = [lid for lid in layer_ids_to_remove if lid in self.my_layers]
1197
+ if not layers_to_remove:
1198
+ return []
1199
+
1200
+ # Don't remove all layers - keep at least 1
1201
+ if len(layers_to_remove) >= len(self.my_layers):
1202
+ layers_to_remove = layers_to_remove[:-1] # Keep last one
1203
+ if not layers_to_remove:
1204
+ logger.warning("[LAYER] Cannot remove all layers")
1205
+ return []
1206
+
1207
+ removed = []
1208
+ for layer_id in layers_to_remove:
1209
+ try:
1210
+ # Delete the layer
1211
+ layer = self.my_layers.pop(layer_id)
1212
+ del layer
1213
+ removed.append(layer_id)
1214
+ except Exception as e:
1215
+ logger.error(f"[LAYER] Failed to remove layer {layer_id}: {e}")
1216
+
1217
+ if removed:
1218
+ self.my_layer_ids = sorted(self.my_layers.keys())
1219
+
1220
+ # UPDATE LAYER POOL: Remove this node from removed layers' assignments
1221
+ if layer_pool:
1222
+ try:
1223
+ with layer_pool.lock:
1224
+ for layer_id in removed:
1225
+ if layer_id in layer_pool.layer_assignments:
1226
+ # Remove our node from this layer's holders
1227
+ layer_pool.layer_assignments[layer_id] = [
1228
+ a for a in layer_pool.layer_assignments[layer_id]
1229
+ if a.node_id != self.node_id
1230
+ ]
1231
+ logger.info(f"[LAYER] Updated layer_pool: removed self from layer {layer_id}")
1232
+ except Exception as e:
1233
+ logger.warning(f"[LAYER] Could not update layer_pool: {e}")
1234
+
1235
+ # UPDATE DHT: Change announced layer range
1236
+ if p2p_manager and self.my_layer_ids:
1237
+ try:
1238
+ new_start = min(self.my_layer_ids)
1239
+ new_end = max(self.my_layer_ids)
1240
+ p2p_manager.start_layer = new_start
1241
+ p2p_manager.end_layer = new_end
1242
+ p2p_manager.shard_range = f"{new_start}-{new_end}"
1243
+ logger.info(f"[LAYER] Updated P2P shard_range: {new_start}-{new_end}")
1244
+
1245
+ # Re-announce immediately so network knows
1246
+ if hasattr(p2p_manager, '_announce_once'):
1247
+ p2p_manager._announce_once(verbose=True)
1248
+ except Exception as e:
1249
+ logger.warning(f"[LAYER] Could not update P2P shard_range: {e}")
1250
+
1251
+ # Force garbage collection to free memory
1252
+ import gc
1253
+ gc.collect()
1254
+ if torch.cuda.is_available():
1255
+ torch.cuda.empty_cache()
1256
+
1257
+ logger.info(f"[LAYER] ✅ Removed {len(removed)} layers: {removed}")
1258
+ logger.info(f"[LAYER] Now holding {len(self.my_layer_ids)} layers")
1259
+
1260
+ # NOTIFY: Call callback so node can sync its state
1261
+ if self._on_layers_changed:
1262
+ try:
1263
+ self._on_layers_changed(self.my_layer_ids)
1264
+ except Exception as e:
1265
+ logger.warning(f"[LAYER] Callback failed: {e}")
1266
+
1267
+ return removed
1268
+
1269
+ def expand_vocabulary(self, new_vocab_size: int) -> bool:
1270
+ """
1271
+ Expand embedding and lm_head to accommodate a larger vocabulary.
1272
+
1273
+ This is called when the tokenizer learns new BPE merges that exceed
1274
+ the current vocabulary capacity. The expansion preserves existing
1275
+ token embeddings while initializing new ones.
1276
+
1277
+ For an ever-growing decentralized LLM, vocabulary expansion is essential
1278
+ as millions of users contribute diverse training data across languages
1279
+ and domains.
1280
+
1281
+ Args:
1282
+ new_vocab_size: The new vocabulary size (must be > current capacity)
1283
+
1284
+ Returns:
1285
+ True if expansion was successful, False otherwise
1286
+ """
1287
+ if new_vocab_size <= self.vocab_capacity:
1288
+ return True # No expansion needed
1289
+
1290
+ # Check against max (if set)
1291
+ if MAX_VOCAB_SIZE is not None and new_vocab_size > MAX_VOCAB_SIZE:
1292
+ logger.warning(f"Requested vocab {new_vocab_size} exceeds MAX_VOCAB_SIZE {MAX_VOCAB_SIZE}")
1293
+ new_vocab_size = MAX_VOCAB_SIZE
1294
+
1295
+ # Round up to next VOCAB_GROWTH_CHUNK for efficient memory alignment
1296
+ new_capacity = ((new_vocab_size + VOCAB_GROWTH_CHUNK - 1) // VOCAB_GROWTH_CHUNK) * VOCAB_GROWTH_CHUNK
1297
+ if MAX_VOCAB_SIZE is not None:
1298
+ new_capacity = min(new_capacity, MAX_VOCAB_SIZE)
1299
+
1300
+ # MEMORY CHECK: Estimate memory needed for expansion
1301
+ # Embedding + LM head expansion: 2 * (new - old) * hidden_dim * 4 bytes (float32)
1302
+ hidden_dim = self.architecture.hidden_dim
1303
+ expansion_params = 2 * (new_capacity - self.vocab_capacity) * hidden_dim
1304
+ expansion_memory_mb = (expansion_params * 4) / (1024 * 1024) # Just weights, no optimizer yet
1305
+
1306
+ # Check available memory (GPU or CPU)
1307
+ try:
1308
+ if torch.cuda.is_available() and self.device != "cpu":
1309
+ # Check GPU memory
1310
+ free_memory_mb = (torch.cuda.get_device_properties(0).total_memory -
1311
+ torch.cuda.memory_allocated()) / (1024 * 1024)
1312
+ else:
1313
+ # Check system RAM
1314
+ import psutil
1315
+ free_memory_mb = psutil.virtual_memory().available / (1024 * 1024)
1316
+
1317
+ # Need at least 2x expansion memory (weights + temporary copy during expansion)
1318
+ required_mb = expansion_memory_mb * 2
1319
+
1320
+ if free_memory_mb < required_mb:
1321
+ logger.warning(f"[VOCAB] Insufficient memory for expansion: need {required_mb:.0f}MB, "
1322
+ f"have {free_memory_mb:.0f}MB free")
1323
+
1324
+ # TRY: Remove some layers to make room for vocab expansion
1325
+ # Vocab is more important than extra layers (all nodes need same vocab)
1326
+ memory_per_layer = estimate_memory_per_layer(self.architecture)
1327
+ layers_to_free = int((required_mb - free_memory_mb) / memory_per_layer) + 1
1328
+
1329
+ if len(self.my_layers) > layers_to_free + 1:  # only if more than one layer would remain
1330
+ # Remove highest-numbered layers (least important for Driver/Validator)
1331
+ layers_to_remove = sorted(self.my_layers.keys(), reverse=True)[:layers_to_free]
1332
+ logger.warning(f"[VOCAB] Attempting to free memory by removing {layers_to_free} layers: {layers_to_remove}")
1333
+
1334
+ # Pass layer_pool so network state is updated
1335
+ # p2p_manager will be notified via layer_pool.on_layers_changed callback
1336
+ removed = self.remove_layers(
1337
+ layers_to_remove,
1338
+ layer_pool=self.layer_pool,
1339
+ p2p_manager=getattr(self, '_p2p_manager', None)
1340
+ )
1341
+ if removed:
1342
+ # Recalculate free memory after layer removal
1343
+ if torch.cuda.is_available() and self.device != "cpu":
1344
+ free_memory_mb = (torch.cuda.get_device_properties(0).total_memory -
1345
+ torch.cuda.memory_allocated()) / (1024 * 1024)
1346
+ else:
1347
+ import psutil
1348
+ free_memory_mb = psutil.virtual_memory().available / (1024 * 1024)
1349
+
1350
+ if free_memory_mb >= required_mb:
1351
+ logger.info(f"[VOCAB] ✅ Freed enough memory by removing {len(removed)} layers")
1352
+ # Continue with expansion below
1353
+ else:
1354
+ logger.warning(f"[VOCAB] Still insufficient memory after layer removal")
1355
+ return False
1356
+ else:
1357
+ logger.warning(f"[VOCAB] Could not remove layers, capping expansion")
1358
+ return False
1359
+ else:
1360
+ logger.warning(f"[VOCAB] Not enough layers to remove, capping expansion")
1361
+ return False
1362
+
1363
+ logger.info(f"[VOCAB] Memory check passed: {expansion_memory_mb:.0f}MB needed, "
1364
+ f"{free_memory_mb:.0f}MB available")
1365
+ except Exception as e:
1366
+ logger.warning(f"[VOCAB] Could not check memory: {e}, proceeding with expansion")
1367
+
1368
+ logger.info(f"[VOCAB] Expanding vocabulary: {self.vocab_capacity} → {new_capacity}")
1369
+
1370
+ try:
1371
+ # Expand embedding if we have it
1372
+ if self.has_embedding and self.embedding is not None:
1373
+ old_embedding = self.embedding
1374
+ old_vocab = old_embedding.weight.shape[0]
1375
+ hidden_dim = old_embedding.weight.shape[1]
1376
+
1377
+ # Create new larger embedding
1378
+ new_embedding = torch.nn.Embedding(new_capacity, hidden_dim)
1379
+ new_embedding.to(self.device)
1380
+
1381
+ # Copy existing embeddings
1382
+ with torch.no_grad():
1383
+ new_embedding.weight[:old_vocab] = old_embedding.weight
1384
+ # Initialize new embeddings with small random values
1385
+ # (similar to how transformers initialize)
1386
+ std = 0.02
1387
+ new_embedding.weight[old_vocab:].normal_(mean=0.0, std=std)
1388
+
1389
+ # Replace old embedding
1390
+ del self.embedding
1391
+ self.embedding = new_embedding
1392
+
1393
+ logger.info(f"[VOCAB] Expanded embedding: {old_vocab} → {new_capacity} tokens")
1394
+
1395
+ # Expand lm_head if we have it
1396
+ if self.has_lm_head and self.lm_head is not None:
1397
+ old_lm_head = self.lm_head
1398
+ old_vocab = old_lm_head.weight.shape[0]
1399
+ hidden_dim = old_lm_head.weight.shape[1]
1400
+
1401
+ # Create new larger lm_head
1402
+ new_lm_head = torch.nn.Linear(hidden_dim, new_capacity, bias=False)
1403
+ new_lm_head.to(self.device)
1404
+
1405
+ # Copy existing weights
1406
+ with torch.no_grad():
1407
+ new_lm_head.weight[:old_vocab] = old_lm_head.weight
1408
+ # Initialize new output weights
1409
+ std = 0.02
1410
+ new_lm_head.weight[old_vocab:].normal_(mean=0.0, std=std)
1411
+
1412
+ # Replace old lm_head
1413
+ del self.lm_head
1414
+ self.lm_head = new_lm_head
1415
+
1416
+ logger.info(f"[VOCAB] Expanded lm_head: {old_vocab} → {new_capacity} output classes")
1417
+
1418
+ # Update capacity
1419
+ old_capacity = self.vocab_capacity
1420
+ self.vocab_capacity = new_capacity
1421
+
1422
+ # Force garbage collection
1423
+ import gc
1424
+ gc.collect()
1425
+ if torch.cuda.is_available():
1426
+ torch.cuda.empty_cache()
1427
+
1428
+ logger.info(f"[VOCAB] ✅ Vocabulary expansion complete: {old_capacity} → {new_capacity}")
1429
+ return True
1430
+
1431
+ except Exception as e:
1432
+ logger.error(f"[VOCAB] Failed to expand vocabulary: {e}")
1433
+ return False
1434
+
1435
+ def check_and_expand_vocab_if_needed(self) -> bool:
1436
+ """
1437
+ Check if tokenizer vocabulary exceeds current capacity and expand if needed.
1438
+
1439
+ This should be called periodically (e.g., after loading new tokenizer from CDN)
1440
+ to ensure the model can handle all tokens in the current vocabulary.
1441
+
1442
+ Returns:
1443
+ True if no expansion needed or expansion successful, False on failure
1444
+ """
1445
+ if self.tokenizer is None:
1446
+ return True
1447
+
1448
+ current_vocab = self.tokenizer.current_vocab_size
1449
+ if current_vocab > self.vocab_capacity:
1450
+ logger.info(f"[VOCAB] Tokenizer vocab ({current_vocab}) exceeds capacity ({self.vocab_capacity})")
1451
+ return self.expand_vocabulary(current_vocab)
1452
+
1453
+ return True
1454
+
1455
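# Illustrative call pattern (sketch, not part of the package source): after refreshing
# the tokenizer, e.g. from the CDN, a node re-checks capacity before continuing:
#
#   model.tokenizer = refreshed_tokenizer   # hypothetical variable
#   if not model.check_and_expand_vocab_if_needed():
#       logger.warning("Vocab expansion failed; keeping current capacity")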
+ def forward_my_layers(
1456
+ self,
1457
+ hidden_states: torch.Tensor,
1458
+ start_layer: Optional[int] = None,
1459
+ end_layer: Optional[int] = None,
1460
+ ) -> torch.Tensor:
1461
+ """Forward through my assigned layers."""
1462
+ if start_layer is None:
1463
+ start_layer = min(self.my_layer_ids) if self.my_layer_ids else 0
1464
+ if end_layer is None:
1465
+ end_layer = max(self.my_layer_ids) + 1 if self.my_layer_ids else 0
1466
+
1467
+ x = hidden_states
1468
+
1469
+ for layer_id in range(start_layer, end_layer):
1470
+ if layer_id in self.my_layers:
1471
+ x, _ = self.my_layers[layer_id](x)
1472
+
1473
+ return x
1474
+
1475
+ def embed(self, input_ids: torch.Tensor) -> torch.Tensor:
1476
+ """Embed input tokens (only if I hold embedding)."""
1477
+ if not self.has_embedding:
1478
+ raise RuntimeError("This node does not hold the embedding layer")
1479
+ return self.embedding(input_ids)
1480
+
1481
+ def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
1482
+ """Compute logits (only if I hold LM head)."""
1483
+ if not self.has_lm_head:
1484
+ raise RuntimeError("This node does not hold the LM head")
1485
+ x = self.final_norm(hidden_states)
1486
+ return self.lm_head(x)
1487
+
1488
+ def get_num_params(self) -> int:
1489
+ """Get number of parameters on this node."""
1490
+ total = 0
1491
+ for layer in self.my_layers.values():
1492
+ total += sum(p.numel() for p in layer.parameters())
1493
+ if self.embedding:
1494
+ total += sum(p.numel() for p in self.embedding.parameters())
1495
+ if self.lm_head:
1496
+ total += sum(p.numel() for p in self.lm_head.parameters())
1497
+ if self.final_norm:
1498
+ total += sum(p.numel() for p in self.final_norm.parameters())
1499
+ return total
1500
+
1501
+ def parameters(self):
1502
+ """Yield all parameters for this node's model components (for optimizer/gradient clipping)."""
1503
+ for layer in self.my_layers.values():
1504
+ yield from layer.parameters()
1505
+ if self.embedding:
1506
+ yield from self.embedding.parameters()
1507
+ if self.lm_head:
1508
+ yield from self.lm_head.parameters()
1509
+ if self.final_norm:
1510
+ yield from self.final_norm.parameters()
1511
+
1512
+ def named_parameters(self, prefix: str = '', recurse: bool = True):
1513
+ """
1514
+ Yield (name, param) tuples for all parameters.
1515
+
1516
+ This is the standard PyTorch interface for iterating over named parameters.
1517
+ """
1518
+ # Layers
1519
+ for layer_id, layer in self.my_layers.items():
1520
+ layer_prefix = f"{prefix}layers.{layer_id}." if prefix else f"layers.{layer_id}."
1521
+ for name, param in layer.named_parameters(prefix='', recurse=recurse):
1522
+ yield layer_prefix + name, param
1523
+
1524
+ # Embedding
1525
+ if self.embedding:
1526
+ emb_prefix = f"{prefix}embedding." if prefix else "embedding."
1527
+ for name, param in self.embedding.named_parameters(prefix='', recurse=recurse):
1528
+ yield emb_prefix + name, param
1529
+
1530
+ # LM Head
1531
+ if self.lm_head:
1532
+ head_prefix = f"{prefix}lm_head." if prefix else "lm_head."
1533
+ for name, param in self.lm_head.named_parameters(prefix='', recurse=recurse):
1534
+ yield head_prefix + name, param
1535
+
1536
+ # Final Norm
1537
+ if self.final_norm:
1538
+ norm_prefix = f"{prefix}final_norm." if prefix else "final_norm."
1539
+ for name, param in self.final_norm.named_parameters(prefix='', recurse=recurse):
1540
+ yield norm_prefix + name, param
1541
+
1542
+ def state_dict(self) -> Dict[str, Any]:
1543
+ """
1544
+ Return the state dictionary of the model.
1545
+
1546
+ This is the standard PyTorch interface for saving model state.
1547
+ """
1548
+ state = {}
1549
+
1550
+ # Layers
1551
+ for layer_id, layer in self.my_layers.items():
1552
+ for name, param in layer.state_dict().items():
1553
+ state[f"layers.{layer_id}.{name}"] = param
1554
+
1555
+ # Embedding
1556
+ if self.embedding:
1557
+ for name, param in self.embedding.state_dict().items():
1558
+ state[f"embedding.{name}"] = param
1559
+
1560
+ # LM Head
1561
+ if self.lm_head:
1562
+ for name, param in self.lm_head.state_dict().items():
1563
+ state[f"lm_head.{name}"] = param
1564
+
1565
+ # Final Norm
1566
+ if self.final_norm:
1567
+ for name, param in self.final_norm.state_dict().items():
1568
+ state[f"final_norm.{name}"] = param
1569
+
1570
+ return state
1571
+
1572
+ def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
1573
+ """
1574
+ Load state dictionary into the model.
1575
+
1576
+ This is the standard PyTorch interface for loading model state.
1577
+ """
1578
+ # Group state by component
1579
+ layer_states: Dict[int, Dict[str, Any]] = {}
1580
+ embedding_state: Dict[str, Any] = {}
1581
+ lm_head_state: Dict[str, Any] = {}
1582
+ final_norm_state: Dict[str, Any] = {}
1583
+
1584
+ for key, value in state_dict.items():
1585
+ if key.startswith("layers."):
1586
+ parts = key.split(".", 2)
1587
+ layer_id = int(parts[1])
1588
+ param_name = parts[2]
1589
+ if layer_id not in layer_states:
1590
+ layer_states[layer_id] = {}
1591
+ layer_states[layer_id][param_name] = value
1592
+ elif key.startswith("embedding."):
1593
+ param_name = key[len("embedding."):]
1594
+ embedding_state[param_name] = value
1595
+ elif key.startswith("lm_head."):
1596
+ param_name = key[len("lm_head."):]
1597
+ lm_head_state[param_name] = value
1598
+ elif key.startswith("final_norm."):
1599
+ param_name = key[len("final_norm."):]
1600
+ final_norm_state[param_name] = value
1601
+
1602
+ # Load into components
1603
+ for layer_id, layer in self.my_layers.items():
1604
+ if layer_id in layer_states:
1605
+ layer.load_state_dict(layer_states[layer_id], strict=strict)
1606
+
1607
+ if self.embedding and embedding_state:
1608
+ self.embedding.load_state_dict(embedding_state, strict=strict)
1609
+
1610
+ if self.lm_head and lm_head_state:
1611
+ self.lm_head.load_state_dict(lm_head_state, strict=strict)
1612
+
1613
+ if self.final_norm and final_norm_state:
1614
+ self.final_norm.load_state_dict(final_norm_state, strict=strict)
1615
+
1616
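# Sketch of the standard PyTorch round-trip these interfaces enable (path is illustrative):
#
#   torch.save(model.state_dict(), "my_shard.pt")
#   model.load_state_dict(torch.load("my_shard.pt", map_location="cpu"), strict=False)
#
# Keys for layers this node does not hold are simply ignored by the grouping above.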
+ def zero_grad(self, set_to_none: bool = False):
1617
+ """
1618
+ Zero all gradients.
1619
+
1620
+ This is the standard PyTorch interface for zeroing gradients.
1621
+ """
1622
+ for param in self.parameters():
1623
+ if param.grad is not None:
1624
+ if set_to_none:
1625
+ param.grad = None
1626
+ else:
1627
+ param.grad.zero_()
1628
+
1629
+ def train(self, mode: bool = True) -> 'DynamicNeuroLLM':
1630
+ """
1631
+ Set the model to training mode.
1632
+
1633
+ This is the standard PyTorch interface for setting training mode
1634
+ on all submodules.
1635
+ """
1636
+ self.training = mode
1637
+ for layer in self.my_layers.values():
1638
+ layer.train(mode)
1639
+ if self.embedding:
1640
+ self.embedding.train(mode)
1641
+ if self.lm_head:
1642
+ self.lm_head.train(mode)
1643
+ if self.final_norm:
1644
+ self.final_norm.train(mode)
1645
+ return self
1646
+
1647
+ def eval(self) -> 'DynamicNeuroLLM':
1648
+ """Set the model to evaluation mode."""
1649
+ return self.train(False)
1650
+
1651
+ def get_my_contribution(self) -> Dict[str, Any]:
1652
+ """Get this node's contribution to the network."""
1653
+ capacity = self.layer_pool.get_network_capacity()
1654
+
1655
+ return {
1656
+ "node_id": self.node_id[:16] + "...",
1657
+ "my_layers": self.my_layer_ids,
1658
+ "my_params": self.get_num_params(),
1659
+ "has_embedding": self.has_embedding,
1660
+ "has_lm_head": self.has_lm_head,
1661
+ "network_total_layers": capacity.assigned_layers,
1662
+ "network_total_nodes": capacity.total_nodes,
1663
+ "contribution_ratio": len(self.my_layer_ids) / max(1, capacity.assigned_layers),
1664
+ }
1665
+
1666
+
1667
+ def calculate_reward_multiplier(
1668
+ num_layers_held: int,
1669
+ total_network_layers: int,
1670
+ has_embedding: bool,
1671
+ has_lm_head: bool
1672
+ ) -> float:
1673
+ """
1674
+ Calculate NEURO reward multiplier based on contribution.
1675
+
1676
+ Roles:
1677
+ - Worker: Standard reward based on layers
1678
+ - Driver (Embedding): 1.2x bonus (bandwidth cost)
1679
+ - Validator (Head): 1.2x bonus (compute/consensus cost)
1680
+ """
1681
+ if total_network_layers == 0:
1682
+ return 1.0
1683
+
1684
+ # Base multiplier from layer contribution
1685
+ layer_ratio = num_layers_held / total_network_layers
1686
+ base_multiplier = 1.0 + layer_ratio # 1.0 to 2.0 based on layers
1687
+
1688
+ # Bonus for critical components (Roles)
1689
+ if has_embedding:
1690
+ base_multiplier *= 1.2 # 20% bonus for Driving (Data bandwidth)
1691
+ if has_lm_head:
1692
+ base_multiplier *= 1.2 # 20% bonus for Validating (Loss calc + Gradient origin)
1693
+
1694
+ return base_multiplier
1695
+
1696
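# Worked example with assumed numbers: a node holding 8 of 32 layers plus the embedding:
#   layer_ratio     = 8 / 32       = 0.25
#   base_multiplier = 1.0 + 0.25   = 1.25
#   with embedding  = 1.25 * 1.2   = 1.50
# Holding the LM head as well would yield 1.25 * 1.2 * 1.2 = 1.80.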
+
1697
+ # ============================================================================
1698
+ # DYNAMIC NEURO NODE - The Main Node Class
1699
+ # ============================================================================
1700
+
1701
+ class DynamicNeuroNode:
1702
+ """
1703
+ A truly decentralized node that contributes based on available memory.
1704
+
1705
+ NO PHASES. NO CENTRAL COORDINATION.
1706
+
1707
+ How it works:
1708
+ 1. Node starts, detects available memory
1709
+ 2. Registers with network, gets assigned layers
1710
+ 3. Loads only the layers it's responsible for
1711
+ 4. Participates in training (computes gradients for its layers)
1712
+ 5. Participates in inference (forwards through its layers)
1713
+ 6. Earns NEURO proportional to its contribution
1714
+
1715
+ The more memory you have, the more layers you hold, the more you earn.
1716
+ """
1717
+
1718
+ CHECKPOINT_DIR = None # Set in __init__
1719
+
1720
+ def __init__(
1721
+ self,
1722
+ node_id: str,
1723
+ port: int = 8000,
1724
+ tracker_url: str = "https://neuroshard.com/api/tracker",
1725
+ node_token: Optional[str] = None,
1726
+ device: str = "cpu",
1727
+ available_memory_mb: Optional[float] = None,
1728
+ enable_training: bool = True,
1729
+ max_storage_mb: float = 100.0,
1730
+ max_cpu_threads: Optional[int] = None,
1731
+ ):
1732
+ self.node_id = node_id
1733
+ self.port = port
1734
+ self.tracker_url = tracker_url
1735
+ self.node_token = node_token
1736
+
1737
+ # Detect device automatically if "auto" or "cpu" (backward compatibility)
1738
+ if device in ("auto", "cpu"):
1739
+ if torch.cuda.is_available():
1740
+ self.device = "cuda"
1741
+ logger.info(f"[NODE] GPU detected: CUDA available")
1742
+ elif torch.backends.mps.is_available():
1743
+ # MPS (Apple Silicon GPU) - enabled now that we have GIL yields
1744
+ self.device = "mps"
1745
+ logger.info(f"[NODE] GPU detected: Apple Metal (MPS)")
1746
+ else:
1747
+ self.device = "cpu"
1748
+ logger.info(f"[NODE] No GPU detected, using CPU")
1749
+
1750
+ # Help debug why CUDA isn't available
1751
+ import subprocess
1752
+ import sys
1753
+
1754
+ # Check if NVIDIA GPU exists
1755
+ has_nvidia_gpu = False
1756
+ try:
1757
+ if sys.platform != 'darwin':  # macOS has no NVIDIA/CUDA support (MPS is used instead)
+ result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=2)
+ has_nvidia_gpu = result.returncode == 0
1766
+ except Exception:
1767
+ pass
1768
+
1769
+ # Detailed diagnostic
1770
+ logger.info(f"[NODE] torch.cuda.is_available() = False")
1771
+
1772
+ # Check if PyTorch was built with CUDA
1773
+ try:
1774
+ cuda_available = torch.cuda.is_available()
1775
+ cuda_built = getattr(torch.version, 'cuda', None)
1776
+ torch_version = torch.__version__
1777
+ logger.info(f"[NODE] PyTorch version: {torch_version}")
1778
+ logger.info(f"[NODE] CUDA compiled version: {cuda_built if cuda_built else 'None (CPU-only build)'}")
1779
+ except Exception as e:
1780
+ logger.info(f"[NODE] Could not get CUDA info: {e}")
1781
+
1782
+ # Provide helpful diagnostic
1783
+ if has_nvidia_gpu:
1784
+ logger.warning("⚠️ NVIDIA GPU DETECTED BUT NOT BEING USED!")
1785
+ logger.warning("Your system has an NVIDIA GPU, but this PyTorch installation is CPU-only.")
1786
+ logger.warning("🔧 TO ENABLE GPU (for 5-10x faster training):")
1787
+ logger.warning("If running the .exe (frozen build):")
1788
+ logger.warning(" Unfortunately, the bundled Python environment can't easily be modified.")
1789
+ logger.warning(" We recommend running from source for GPU support.")
1790
+ logger.warning("If running from source:")
1791
+ logger.warning(" pip uninstall torch")
1792
+ logger.warning(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
1793
+ logger.warning("To verify: python -c \"import torch; print(torch.cuda.is_available())\"")
1794
+ else:
1795
+ self.device = device
1796
+ logger.info(f"[NODE] Device manually set to: {self.device}")
1797
+
1798
+ logger.info(f"Using device: {self.device}")
1799
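# Sketch of the device argument's behaviour (constructor calls are illustrative):
#   DynamicNeuroNode(node_id="...", device="auto")  # auto-detect: CUDA > MPS > CPU
#   DynamicNeuroNode(node_id="...", device="mps")   # explicit value is used as-is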
+
1800
+ self.enable_training = enable_training
1801
+ self.max_storage_mb = max_storage_mb
1802
+ self.max_cpu_threads = max_cpu_threads
1803
+
1804
+ # CPU thread limiting is done in runner.py BEFORE any torch operations
1805
+ # (torch.set_num_interop_threads must be called before any parallel work)
1806
+ if max_cpu_threads and self.device == "cpu":
1807
+ torch.set_num_threads(max_cpu_threads) # Intra-op parallelism only
1808
+ logger.info(f"Set PyTorch intra-op threads: {max_cpu_threads}")
1809
+
1810
+ # Detect memory if not provided
1811
+ if available_memory_mb is None:
1812
+ self.available_memory_mb = self._detect_available_memory()
1813
+ else:
1814
+ self.available_memory_mb = available_memory_mb
1815
+
1816
+ # Checkpoint directory
1817
+ from pathlib import Path
1818
+ self.CHECKPOINT_DIR = Path.home() / ".neuroshard" / "checkpoints"
1819
+ self.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)
1820
+
1821
+ # Compute wallet_id from token for stable checkpoint naming
1822
+ # (node_id can change if machine_id changes, but wallet_id is stable)
1823
+ if node_token:
1824
+ self.wallet_id = hashlib.sha256(node_token.encode()).hexdigest()[:16]
1825
+ else:
1826
+ self.wallet_id = self.node_id[:16] # Fallback to node_id
1827
+
1828
+ # Layer pool (shared across network via DHT)
1829
+ self.layer_pool: Optional[DynamicLayerPool] = None
1830
+
1831
+ # My model (only my layers)
1832
+ self.model: Optional[DynamicNeuroLLM] = None
1833
+ self.my_layer_ids: List[int] = []
1834
+
1835
+ # Tokenizer
1836
+ self.tokenizer = None
1837
+
1838
+ # Training components (enable_training set in __init__)
1839
+ self.optimizer: Optional[torch.optim.Optimizer] = None
1840
+ self.training_coordinator = None
1841
+ self.data_manager = None
1842
+ self.gradient_gossip = None
1843
+
1844
+ # Training lock to prevent concurrent training operations
1845
+ # (local training vs pipeline training conflict)
1846
+ self._training_lock = threading.Lock()
1847
+
1848
+ # P2P
1849
+ self.p2p_manager = None
1850
+
1851
+ # Stats
1852
+ self.is_running = False
1853
+ self.total_tokens_processed = 0
1854
+ self.total_training_rounds = 0
1855
+ self.current_loss = float('inf')
1856
+ self.inference_count = 0
1857
+ self.training_contribution_count = 0
1858
+
1859
+ # KV cache for inference
1860
+ self.kv_cache: Dict[str, Any] = {}
1861
+
1862
+ # Training context (keeps tensors alive for backward pass)
1863
+ # session_id -> {input, output, prev_peer}
1864
+ self.training_context: Dict[str, Any] = {}
1865
+
1866
+ logger.info(f"DynamicNeuroNode initialized: memory={self.available_memory_mb:.0f}MB")
1867
+
1868
+ def _detect_available_memory(self) -> float:
1869
+ """Detect available system memory."""
1870
+ try:
1871
+ import psutil
1872
+ mem = psutil.virtual_memory()
1873
+ # Use 70% of available memory for safety
1874
+ return mem.available * 0.7 / (1024 * 1024)
1875
+ except ImportError:
1876
+ # Fallback
1877
+ return 2000 # Assume 2GB
1878
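# Worked example (assumed figure): with 16 GB reported available by psutil,
#   16 * 1024 MB * 0.7 ≈ 11,469 MB is offered to the layer pool.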
+
1879
+ def start(self):
1880
+ """Start the node."""
1881
+ logger.info("Starting DynamicNeuroNode...")
1882
+
1883
+ # 1. Initialize layer pool
1884
+ dht = None
1885
+ if self.p2p_manager and hasattr(self.p2p_manager, 'dht'):
1886
+ dht = self.p2p_manager.dht
1887
+ self.layer_pool = DynamicLayerPool(dht_protocol=dht)
1888
+
1889
+ # Pass device hint for memory calculations (CPU needs more conservative limits)
1890
+ self.layer_pool._device_hint = self.device
1891
+
1892
+ # 1b. SMART ARCHITECTURE RECONCILIATION
1893
+ # This handles the case where the network has evolved while we were offline
1894
+ self._reconcile_architecture()
1895
+
1896
+ # 1c. PRE-FETCH TOKENIZER VOCAB SIZE for accurate memory calculation
1897
+ # This is critical for dynamic vocab - we need to know vocab size BEFORE
1898
+ # assigning layers, otherwise we'll assign too many and OOM when vocab expands
1899
+ self._prefetch_vocab_capacity()
1900
+
1901
+ # 2. Get staked amount from ledger (for Validator eligibility)
1902
+ staked_amount = 0.0
1903
+ if self.p2p_manager and self.p2p_manager.ledger:
1904
+ try:
1905
+ account_info = self.p2p_manager.ledger.get_account_info()
1906
+ staked_amount = account_info.get("stake", 0.0)
1907
+ logger.info(f"Current stake: {staked_amount:.2f} NEURO")
1908
+ except Exception as e:
1909
+ logger.debug(f"Could not get stake info: {e}")
1910
+
1911
+ # 3. Register with network and get layer assignments
1912
+ self.my_layer_ids = self.layer_pool.register_node(
1913
+ node_id=self.node_id,
1914
+ node_url=f"http://localhost:{self.port}",
1915
+ grpc_addr=f"localhost:{self.port + 1000}",
1916
+ available_memory_mb=self.available_memory_mb,
1917
+ staked_amount=staked_amount
1918
+ )
1919
+
1920
+ logger.info(f"Assigned {len(self.my_layer_ids)} layers: {self.my_layer_ids}")
1921
+
1922
+ # 4. Initialize model with my layers
1923
+ self.model = DynamicNeuroLLM(
1924
+ node_id=self.node_id,
1925
+ layer_pool=self.layer_pool,
1926
+ device=self.device
1927
+ )
1928
+ self.model.initialize_layers(self.my_layer_ids)
1929
+
1930
+ # 4b. Set up callback for dynamic layer changes
1931
+ # When model removes layers (e.g., for vocab expansion), sync node state
1932
+ def on_layers_changed(new_layer_ids: List[int]):
1933
+ self.my_layer_ids = new_layer_ids
1934
+ # Update P2P shard_range if available
1935
+ if self.p2p_manager and new_layer_ids:
1936
+ new_start = min(new_layer_ids)
1937
+ new_end = max(new_layer_ids)
1938
+ self.p2p_manager.start_layer = new_start
1939
+ self.p2p_manager.end_layer = new_end
1940
+ self.p2p_manager.shard_range = f"{new_start}-{new_end}"
1941
+ logger.info(f"[NODE] Synced layer_ids after change: {new_layer_ids}")
1942
+
1943
+ self.model._on_layers_changed = on_layers_changed
1944
+
1945
+ # 5. Initialize tokenizer with learned BPE merges from CDN
1946
+ from neuroshard.core.model.tokenizer import get_neuro_tokenizer, NeuroTokenizer
1947
+ self.tokenizer = get_neuro_tokenizer()
1948
+ self._load_learned_tokenizer() # Update with BPE merges from CDN
1949
+
1950
+ # 6. Try to load existing checkpoint (resume training)
1951
+ self._load_checkpoint()
1952
+
1953
+ # 7. Setup training
1954
+ if self.enable_training:
1955
+ self._setup_training()
1956
+
1957
+ self.is_running = True
1958
+
1959
+ # Log contribution
1960
+ contribution = self.model.get_my_contribution()
1961
+ logger.info(f"Node started: {contribution['my_params']/1e6:.1f}M params, "
1962
+ f"{len(self.my_layer_ids)} layers, "
1963
+ f"embed={self.model.has_embedding}, head={self.model.has_lm_head}")
1964
+
1965
+ # Verify model is actually on the expected device
1966
+ if self.my_layer_ids and self.model.my_layers:
1967
+ first_layer = self.model.my_layers[self.my_layer_ids[0]]
1968
+ param_device = next(first_layer.parameters()).device
1969
+ if str(param_device) != self.device and not (self.device == "cuda" and "cuda" in str(param_device)):
1970
+ logger.error(f"[DEVICE] Model device mismatch! Expected {self.device}, got {param_device}")
1971
+ else:
1972
+ logger.info(f"[DEVICE] Model verified on: {param_device}")
1973
+
1974
+ def _load_learned_tokenizer(self):
1975
+ """
1976
+ Load learned BPE tokenizer from Genesis CDN.
1977
+
1978
+ This ensures the tokenizer used for inference matches the one used
1979
+ for training data tokenization, providing consistency across the network.
1980
+ """
1981
+ import requests
1982
+ import os
1983
+
1984
+ GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
1985
+ cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")
1986
+
1987
+ try:
1988
+ tokenizer_url = f"{GENESIS_CDN_URL}/tokenizer.json"
1989
+ tokenizer_cache_path = os.path.join(cache_dir, "tokenizer.json")
1990
+
1991
+ # Try to fetch from CDN
1992
+ try:
1993
+ logger.debug(f"[TOKENIZER] Checking for learned tokenizer from {tokenizer_url}...")
1994
+ resp = requests.get(tokenizer_url, timeout=10)
1995
+
1996
+ if resp.status_code == 200:
1997
+ remote_tokenizer_data = resp.json()
1998
+ remote_vocab_size = remote_tokenizer_data.get("next_merge_id", 0)
1999
+
2000
+ # Cache locally
2001
+ os.makedirs(cache_dir, exist_ok=True)
2002
+ with open(tokenizer_cache_path, 'w') as f:
2003
+ f.write(resp.text)
2004
+
2005
+ # Update tokenizer if remote has more merges
2006
+ if remote_vocab_size > self.tokenizer.next_merge_id:
2007
+ from neuroshard.core.model.tokenizer import NeuroTokenizer
2008
+ learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)
2009
+
2010
+ self.tokenizer.merges = learned_tokenizer.merges
2011
+ self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
2012
+ self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id
2013
+
2014
+ logger.info(f"[TOKENIZER] Loaded BPE tokenizer: {self.tokenizer.current_vocab_size} tokens, {len(self.tokenizer.merges)} merges")
2015
+
2016
+ # CRITICAL: Check if model needs vocabulary expansion after loading new tokenizer
2017
+ if self.model is not None:
2018
+ self.model.tokenizer = self.tokenizer
2019
+ self.model.check_and_expand_vocab_if_needed()
2020
+ # Update layer pool's vocab_capacity for future layer calculations
2021
+ if hasattr(self, 'layer_pool') and self.layer_pool:
2022
+ self.layer_pool.vocab_capacity = self.model.vocab_capacity
2023
+ else:
2024
+ logger.debug(f"[TOKENIZER] Already up to date: {self.tokenizer.current_vocab_size} tokens")
2025
+ return
2026
+ except requests.RequestException as e:
2027
+ logger.debug(f"[TOKENIZER] CDN fetch failed: {e}")
2028
+
2029
+ # Fallback to cached version
2030
+ if os.path.exists(tokenizer_cache_path) and self.tokenizer.next_merge_id <= 266:
2031
+ try:
2032
+ from neuroshard.core.model.tokenizer import NeuroTokenizer
2033
+ learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)
2034
+
2035
+ if learned_tokenizer.next_merge_id > self.tokenizer.next_merge_id:
2036
+ self.tokenizer.merges = learned_tokenizer.merges
2037
+ self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
2038
+ self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id
2039
+ logger.info(f"[TOKENIZER] Loaded cached BPE tokenizer: {self.tokenizer.current_vocab_size} tokens")
2040
+
2041
+ # CRITICAL: Check if model needs vocabulary expansion
2042
+ if self.model is not None:
2043
+ self.model.tokenizer = self.tokenizer
2044
+ self.model.check_and_expand_vocab_if_needed()
2045
+ # Update layer pool's vocab_capacity for future layer calculations
2046
+ if hasattr(self, 'layer_pool') and self.layer_pool:
2047
+ self.layer_pool.vocab_capacity = self.model.vocab_capacity
2048
+ except Exception as e:
2049
+ logger.warning(f"[TOKENIZER] Failed to load cached tokenizer: {e}")
2050
+
2051
+ except Exception as e:
2052
+ logger.warning(f"[TOKENIZER] Error loading learned tokenizer: {e}")
2053
+
2054
+ def _setup_training(self):
2055
+ """Setup training components."""
2056
+ from neuroshard.core.training.distributed import FederatedDataManager
2057
+
2058
+ # Collect all parameters from my layers
2059
+ all_params = []
2060
+ for layer in self.model.my_layers.values():
2061
+ all_params.extend(layer.parameters())
2062
+ if self.model.embedding:
2063
+ all_params.extend(self.model.embedding.parameters())
2064
+ if self.model.lm_head:
2065
+ all_params.extend(self.model.lm_head.parameters())
2066
+ if self.model.final_norm:
2067
+ all_params.extend(self.model.final_norm.parameters())
2068
+
2069
+ self.optimizer = torch.optim.AdamW(all_params, lr=1e-4, weight_decay=0.01)
2070
+
2071
+ self.data_manager = FederatedDataManager(
2072
+ tokenizer=self.tokenizer,
2073
+ max_seq_len=2048
2074
+ )
2075
+
2076
+ # DYNAMIC TRAINING CONFIG: Calculate based on current model size and device
2077
+ # This will be recalculated when model grows via recalculate_training_config()
2078
+ num_layers = len(self.my_layer_ids)
2079
+
2080
+ # Smart gradient checkpointing decision
2081
+ # CRITICAL: Always enable checkpointing for models with many layers!
2082
+ # Without checkpointing, attention scores alone need: batch × heads × seq² × layers × 4 bytes
2083
+ # For 46 layers: 8 × 16 × 2048² × 46 × 4 = ~92GB (way more than any GPU!)
2084
+
2085
+ # SIMPLE RULE: Enable checkpointing if layers > 16 (always safe)
2086
+ # This avoids complex calculations that can have bugs with timing of vocab expansion
2087
+ if num_layers > 16:
2088
+ self._use_gradient_checkpointing = True
2089
+ logger.info(f"[NODE] Gradient checkpointing: ENABLED (layers={num_layers} > 16)")
2090
+ elif self.device != "cuda":
2091
+ # CPU/MPS always use checkpointing for memory efficiency
2092
+ self._use_gradient_checkpointing = True
2093
+ logger.info(f"[NODE] Gradient checkpointing: ENABLED (device={self.device})")
2094
+ else:
2095
+ # Small CUDA models can skip checkpointing for speed
2096
+ self._use_gradient_checkpointing = False
2097
+ logger.info(f"[NODE] Gradient checkpointing: DISABLED (layers={num_layers} ≤ 16, CUDA)")
2098
+
2099
+ # Calculate memory-aware training batch size
2100
+ self._training_batch_size = self._calculate_training_batch_size()
2101
+
2102
+ logger.info(f"Training initialized: batch_size={self._training_batch_size}, "
2103
+ f"checkpointing={self._use_gradient_checkpointing}, "
2104
+ f"layers={num_layers}, device={self.device}")
2105
+
2106
+ # CUDA sanity check: verify GPU is actually usable
2107
+ if self.device == "cuda":
2108
+ try:
2109
+ import time as _time
2110
+ test_tensor = torch.randn(1000, 1000, device="cuda")
2111
+ start = _time.time()
2112
+ _ = torch.matmul(test_tensor, test_tensor)
2113
+ torch.cuda.synchronize()
2114
+ elapsed = _time.time() - start
2115
+ del test_tensor
2116
+ torch.cuda.empty_cache()
2117
+ logger.info(f"[CUDA] GPU sanity check passed: 1000x1000 matmul in {elapsed*1000:.1f}ms")
2118
+ except Exception as e:
2119
+ logger.error(f"[CUDA] GPU sanity check FAILED: {e}")
2120
+ logger.error("[CUDA] Training will likely run on CPU despite device=cuda!")
2121
+
2122
+ def _calculate_training_batch_size(self) -> int:
2123
+ """
2124
+ Calculate optimal batch size based on available memory, device, and model size.
2125
+
2126
+ DYNAMIC: This is called initially and can be recalculated when model grows.
2127
+ SMART: Considers GPU memory, gradient checkpointing, and actual model size.
2128
+ """
2129
+ seq_len = 512 # Typical sequence length
2130
+ hidden_dim = self.layer_pool.current_architecture.hidden_dim
2131
+ num_layers = len(self.my_layer_ids)
2132
+
2133
+ # Calculate model memory footprint (params + gradients + optimizer states)
2134
+ model_params = sum(p.numel() for p in self.model.parameters())
2135
+ # Model memory: weights (fp32=4 bytes) × 4 (weights + grads + adam_m + adam_v)
2136
+ model_memory_mb = (model_params * 4 * 4) / (1024 * 1024)
2137
+
2138
+ # For CUDA, check actual GPU memory available
2139
+ if self.device == "cuda":
2140
+ try:
2141
+ gpu_total = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)
2142
+ gpu_allocated = torch.cuda.memory_allocated(0) / (1024 * 1024)
2143
+ logger.info(f"[NODE] CUDA memory: {gpu_allocated:.0f}MB used / {gpu_total:.0f}MB total")
2144
+ effective_memory_mb = self.available_memory_mb
2145
+ except Exception:
2146
+ effective_memory_mb = self.available_memory_mb
2147
+ else:
2148
+ effective_memory_mb = self.available_memory_mb
2149
+
2150
+ # CORRECT FORMULA: Available for activations = Total - Model memory
2151
+ # Leave 10% buffer for system overhead
2152
+ available_for_activations = max(100, (effective_memory_mb * 0.9) - model_memory_mb)
2153
+
2154
+ # With gradient checkpointing, activation memory is MUCH lower
2155
+ use_checkpointing = getattr(self, '_use_gradient_checkpointing', False)
2156
+ if use_checkpointing:
2157
+ # Checkpointing: Only need to store ~sqrt(num_layers) worth of activations
2158
+ # Plus inputs/outputs at checkpoint boundaries
2159
+ checkpoint_segments = max(1, int(num_layers ** 0.5))
2160
+ # Memory per sample: seq_len × hidden_dim × checkpoint_segments × 4 bytes × 2 (fwd+bwd)
2161
+ mem_per_sample_mb = (seq_len * hidden_dim * checkpoint_segments * 4 * 2) / (1024 * 1024)
2162
+ logger.info(f"[NODE] Gradient checkpointing: {checkpoint_segments} segments "
2163
+ f"(~{mem_per_sample_mb:.1f}MB/sample)")
2164
+ else:
2165
+ # No checkpointing: full activation memory for all layers
2166
+ mem_per_sample_mb = (seq_len * hidden_dim * num_layers * 4 * 2) / (1024 * 1024)
2167
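# Worked example with assumed values (seq_len=512, hidden_dim=2048, 46 layers):
#   with checkpointing:    segments = int(sqrt(46)) = 6
#                          512 * 2048 * 6 * 4 * 2 bytes  ≈ 48 MB per sample
#   without checkpointing: 512 * 2048 * 46 * 4 * 2 bytes ≈ 368 MB per sample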
+
2168
+ logger.info(f"[NODE] Memory budget: total={effective_memory_mb:.0f}MB, "
2169
+ f"model={model_memory_mb:.0f}MB, "
2170
+ f"available_for_activations={available_for_activations:.0f}MB")
2171
+
2172
+ # Calculate max batch size from available memory
2173
+ max_batch = max(1, int(available_for_activations / max(1, mem_per_sample_mb)))
2174
+
2175
+ # SMART CLAMPING based on device capability
2176
+ if self.device == "cuda" and effective_memory_mb > 16000:
2177
+ # High-memory CUDA (Jetson Orin 32GB, RTX 3090 24GB): up to 8
2178
+ max_batch = min(max_batch, 8)
2179
+ elif self.device == "cuda" and effective_memory_mb > 8000:
2180
+ # Medium CUDA: up to 4
2181
+ max_batch = min(max_batch, 4)
2182
+ elif self.device == "cuda":
2183
+ # Small CUDA: up to 2
2184
+ max_batch = min(max_batch, 2)
2185
+ elif num_layers > 100:
2186
+ # Large model on CPU/MPS: conservative
2187
+ max_batch = min(max_batch, 2)
2188
+ else:
2189
+ max_batch = min(max_batch, 4)
2190
+
2191
+ batch_size = max(1, max_batch)
2192
+
2193
+ logger.info(f"[NODE] Training config: batch_size={batch_size}, "
2194
+ f"model={model_params/1e6:.1f}M params ({num_layers} layers × {hidden_dim} dim), "
2195
+ f"checkpointing={use_checkpointing}, device={self.device}")
2196
+
2197
+ return batch_size
2198
+
2199
+ def recalculate_training_config(self):
2200
+ """
2201
+ Recalculate training configuration after model architecture changes.
2202
+
2203
+ Called when:
2204
+ - Model grows (new layers added)
2205
+ - Memory allocation changes
2206
+ - Device changes
2207
+ """
2208
+ old_batch = getattr(self, '_training_batch_size', None)
2209
+ self._training_batch_size = self._calculate_training_batch_size()
2210
+
2211
+ # Update gradient checkpointing based on new model size
2212
+ num_layers = len(self.my_layer_ids)
2213
+ old_checkpointing = getattr(self, '_use_gradient_checkpointing', False)
2214
+
2215
+ # Simple checkpointing rule: enable if layers > 16
2216
+ if num_layers > 16:
2217
+ self._use_gradient_checkpointing = True
2218
+ elif self.device != "cuda":
2219
+ self._use_gradient_checkpointing = True
2220
+ else:
2221
+ self._use_gradient_checkpointing = False
2222
+
2223
+ if old_batch != self._training_batch_size or old_checkpointing != self._use_gradient_checkpointing:
2224
+ logger.info(f"[NODE] Training config updated: batch_size={old_batch}→{self._training_batch_size}, "
2225
+ f"checkpointing={old_checkpointing}→{self._use_gradient_checkpointing}")
2226
+
2227
+ def stop(self):
2228
+ """Stop the node."""
2229
+ logger.info("Stopping DynamicNeuroNode...")
2230
+
2231
+ self.is_running = False
2232
+
2233
+ # Unregister from network
2234
+ if self.layer_pool:
2235
+ self.layer_pool.unregister_node(self.node_id)
2236
+
2237
+ # Save checkpoint
2238
+ self._save_checkpoint()
2239
+
2240
+ logger.info("DynamicNeuroNode stopped")
2241
+
2242
+ def connect_p2p(self, p2p_manager):
2243
+ """Connect to P2P network."""
2244
+ self.p2p_manager = p2p_manager
2245
+
2246
+ # Initialize Data Swarm
2247
+ from neuroshard.core.network.p2p_data import DataSwarm
2248
+
2249
+ # Ensure cache dir exists in a writable location
2250
+ data_cache_dir = self.CHECKPOINT_DIR / "data_cache"
2251
+ data_cache_dir.mkdir(parents=True, exist_ok=True)
2252
+
2253
+ self.swarm = DataSwarm(p2p_manager, cache_dir=str(data_cache_dir))
2254
+
2255
+ # Update layer pool with DHT
2256
+ if self.layer_pool and hasattr(p2p_manager, 'dht'):
2257
+ self.layer_pool.dht = p2p_manager.dht
2258
+
2259
+ # IMPORTANT: Give model access to p2p_manager for dynamic layer updates
2260
+ # When vocab expansion removes layers, the model needs to update DHT
2261
+ if self.model:
2262
+ self.model._p2p_manager = p2p_manager
2263
+
2264
+ logger.info("Connected to P2P network and Data Swarm")
2265
+
2266
+ # ==================== INFERENCE ====================
2267
+
2268
+ def forward(self, input_ids: torch.Tensor, session_id: Optional[str] = None) -> torch.Tensor:
2269
+ """
2270
+ Forward pass - routes through network if needed.
2271
+
2272
+ If this node has all layers: process locally
2273
+ If not: forward to nodes with other layers
2274
+ """
2275
+ # Check if we can do full inference locally
2276
+ capacity = self.layer_pool.get_network_capacity()
2277
+
2278
+ if len(self.my_layer_ids) == capacity.assigned_layers and self.model.has_embedding and self.model.has_lm_head:
2279
+ # We have everything - do local inference
2280
+ return self._forward_local(input_ids)
2281
+ else:
2282
+ # Need to route through network
2283
+ return self._forward_distributed(input_ids, session_id)
2284
+
2285
+ def _forward_local(self, input_ids: torch.Tensor) -> torch.Tensor:
2286
+ """Full local inference (when we have all layers)."""
2287
+ with torch.no_grad():
2288
+ # Embed
2289
+ hidden = self.model.embed(input_ids.to(self.device))
2290
+
2291
+ # Forward through all layers
2292
+ hidden = self.model.forward_my_layers(hidden)
2293
+
2294
+ # Compute logits
2295
+ logits = self.model.compute_logits(hidden)
2296
+
2297
+ self.inference_count += 1
2298
+ self.total_tokens_processed += input_ids.numel()
2299
+
2300
+ return logits
2301
+
2302
+ def _forward_distributed(self, input_ids: torch.Tensor, session_id: Optional[str] = None) -> torch.Tensor:
2303
+ """Distributed inference through network pipeline."""
2304
+ # Get pipeline route
2305
+ route = self.layer_pool.get_pipeline_route()
2306
+
2307
+ if not route:
2308
+ raise RuntimeError("No pipeline route available")
2309
+
2310
+ # Start with embedding
2311
+ if self.model.has_embedding:
2312
+ hidden = self.model.embed(input_ids.to(self.device))
2313
+ else:
2314
+ # Request embedding from holder
2315
+ hidden = self._request_embedding(input_ids)
2316
+
2317
+ # Forward through layers (local or remote)
2318
+ current_layer = 0
2319
+ for layer_id, grpc_addr in route:
2320
+ if layer_id in self.model.my_layers:
2321
+ # Local layer
2322
+ hidden, _ = self.model.my_layers[layer_id](hidden)
2323
+ else:
2324
+ # Remote layer - forward to peer
2325
+ hidden, _ = self._forward_to_peer(grpc_addr, hidden, layer_id)  # returns a (tensor, cache) tuple
2326
+ current_layer = layer_id
2327
+
2328
+ # Compute logits
2329
+ if self.model.has_lm_head:
2330
+ logits = self.model.compute_logits(hidden)
2331
+ else:
2332
+ # Request from holder
2333
+ logits = self._request_logits(hidden)
2334
+
2335
+ self.inference_count += 1
2336
+ self.total_tokens_processed += input_ids.numel()
2337
+
2338
+ return logits
2339
+
2340
+ def forward_pipeline(
2341
+ self,
2342
+ hidden_states: torch.Tensor,
2343
+ attention_mask: Optional[torch.Tensor] = None,
2344
+ position_ids: Optional[torch.Tensor] = None,
2345
+ training_labels: Optional[torch.Tensor] = None,
2346
+ session_id: Optional[str] = None,
2347
+ sender_url: Optional[str] = None,
2348
+ use_cache: bool = False
2349
+ ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
2350
+ """
2351
+ Forward pass for pipeline parallelism (received from peer).
2352
+ """
2353
+ # Enable gradient tracking if training
2354
+ is_training = training_labels is not None
2355
+
2356
+ # Token-id inputs (integer dtype) cannot require grad; grads are enabled after embedding below
+ if is_training and hidden_states.is_floating_point():
+ hidden_states.requires_grad_(True)
+ hidden_states.retain_grad()
2359
+
2360
+ # Check if input is token IDs (embedding request)
2361
+ # Integer dtype or 2D shape [batch, seq] implies input_ids
2362
+ # This happens when a client sends input_ids to the Driver (Layer 0)
2363
+ if (hidden_states.dtype in [torch.long, torch.int64, torch.int32] or
2364
+ len(hidden_states.shape) == 2) and self.model.has_embedding:
2365
+
2366
+ # Ensure correct dtype
2367
+ if hidden_states.dtype != torch.long:
2368
+ hidden_states = hidden_states.to(torch.long)
2369
+
2370
+ # Embed tokens
2371
+ hidden_states = self.model.embed(hidden_states)
2372
+
2373
+ if is_training:
2374
+ hidden_states.requires_grad_(True)
2375
+ hidden_states.retain_grad()
2376
+
2377
+ # Forward through local layers
2378
+ output = self.model.forward_my_layers(hidden_states)
2379
+
2380
+ if is_training and session_id:
2381
+ # Save context for backward pass
2382
+ self.training_context[session_id] = {
2383
+ "input": hidden_states,
2384
+ "output": output,
2385
+ "sender_url": sender_url,
2386
+ "timestamp": time.time()
2387
+ }
2388
+ # Cleanup old sessions
2389
+ now = time.time()
2390
+ to_remove = [s for s, ctx in self.training_context.items() if now - ctx["timestamp"] > 600]
2391
+ for s in to_remove:
2392
+ del self.training_context[s]
2393
+
2394
+ # If we are the Validator (Last Layer holder)
2395
+ # DYNAMIC CHECK: Query layer_pool for current lm_head_holder
2396
+ # This handles the case where a new Validator joined and took over
2397
+ is_current_validator = self.model.has_lm_head
2398
+ if hasattr(self, 'layer_pool') and self.layer_pool:
2399
+ is_current_validator = (self.layer_pool.lm_head_holder == self.node_id)
2400
+
2401
+ if is_current_validator:
2402
+ logits = self.model.compute_logits(output)
2403
+
2404
+ # Calculate Loss if labels present
2405
+ if training_labels is not None:
2406
+ loss = torch.nn.functional.cross_entropy(
2407
+ logits.view(-1, logits.size(-1)),
2408
+ training_labels.view(-1),
2409
+ ignore_index=-100
2410
+ )
2411
+
2412
+ # Use training lock to prevent conflict with local training
2413
+ with self._training_lock:
2414
+ # Trigger Backward Pass
2415
+ self.optimizer.zero_grad()
2416
+ loss.backward()
2417
+
2418
+ # Propagate gradient back to previous node
2419
+ if sender_url and session_id:
2420
+ # The gradient we send back is dL/d(input_hidden_states)
2421
+ # hidden_states.grad is populated by backward()
2422
+ if hidden_states.grad is not None:
2423
+ self._backward_to_peer(
2424
+ sender_url,
2425
+ hidden_states.grad,
2426
+ # Routing is by sender_url, so the exact layer id is not needed here;
+ # pass 0 as a placeholder target shard.
+ 0,
2431
+ session_id
2432
+ )
2433
+
2434
+ # Step Optimizer
2435
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
2436
+ self.optimizer.step()
2437
+
2438
+ self.total_training_rounds += 1
2439
+ self.current_loss = loss.item()
2440
+
2441
+ return logits, None
2442
+
2443
+ return logits, None
2444
+
2445
+ # If we are a Worker (Middle Layer), we need to forward to next peer
2446
+ my_last_layer = max(self.my_layer_ids) if self.my_layer_ids else 0
2447
+ next_layer = my_last_layer + 1
2448
+
2449
+ if self.p2p_manager:
2450
+ next_hop = self.p2p_manager.get_next_hop(next_layer)
2451
+ if next_hop:
2452
+ return self._forward_to_peer(
2453
+ next_hop,
2454
+ output,
2455
+ next_layer,
2456
+ labels=training_labels,
2457
+ session_id=session_id
2458
+ )
2459
+
2460
+ logger.warning(f"Pipeline broken at layer {next_layer}: no peer found")
2461
+ return output, None
2462
+
2463
+ def backward_pipeline(self, grad_output: torch.Tensor, session_id: str):
2464
+ """
2465
+ Backward pass received from next peer.
2466
+ """
2467
+ if session_id not in self.training_context:
2468
+ logger.warning(f"Received backward for unknown session {session_id}")
2469
+ return
2470
+
2471
+ ctx = self.training_context[session_id]
2472
+ output = ctx["output"]
2473
+ input_tensor = ctx["input"]
2474
+ sender_url = ctx["sender_url"]
2475
+
2476
+ # Use training lock to prevent conflict with local training
2477
+ with self._training_lock:
2478
+ # Run local backward
2479
+ # output is the tensor we produced in forward_pipeline
2480
+ # grad_output is dL/d(output) received from next peer
2481
+ self.optimizer.zero_grad()
2482
+ output.backward(grad_output)
2483
+
2484
+ # Propagate back
2485
+ if sender_url and input_tensor.grad is not None:
2486
+ # The sender's URL is enough for routing; _backward_to_peer still takes a
+ # layer_id, so 0 is passed as a placeholder.
2488
+ self._backward_to_peer(sender_url, input_tensor.grad, 0, session_id)
2489
+
2490
+ # Step Optimizer
2491
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
2492
+ self.optimizer.step()
2493
+
2494
+ # Cleanup
2495
+ del self.training_context[session_id]
2496
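# Summary sketch of the pipeline training round-trip implemented above (roles per this file):
#   Driver:    embeds input_ids, forwards via _forward_to_peer to the next hop
#   Worker:    runs forward_my_layers, stores {input, output} in training_context[session_id]
#   Validator: computes logits + loss, calls backward(), then _backward_to_peer(sender_url)
#   Workers:   backward_pipeline replays output.backward(grad_output) and sends
#              input.grad upstream until the Driver is reached.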
+
2497
+ def _request_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
2498
+ """Request embedding from the node that holds it."""
2499
+ # Find a node that holds Layer 0 (Driver)
2500
+ peer_url = None
2501
+
2502
+ # 1. Check layer pool assignments
2503
+ if self.layer_pool:
2504
+ assignments = self.layer_pool.get_layer_holders(0)
2505
+ if assignments:
2506
+ # Pick one (e.g., random for load balancing)
2507
+ import random
2508
+ peer_url = random.choice(assignments).grpc_addr
2509
+
2510
+ # 2. Fallback to P2P manager routing
2511
+ if not peer_url and self.p2p_manager:
2512
+ peer_url = self.p2p_manager.get_next_hop(0)
2513
+
2514
+ if not peer_url:
2515
+ raise RuntimeError("No embedding holder (Driver/Layer 0) found in network")
2516
+
2517
+ # Call peer - Send input_ids to Layer 0 holder
2518
+ # The receiver's forward_pipeline will detect it's input_ids and run embed()
2519
+ result, _ = self._forward_to_peer(peer_url, input_ids, 0)
2520
+ return result
2521
+
2522
+ def _forward_to_peer(self, peer_url: str, hidden: torch.Tensor, layer_id: int, labels: Optional[torch.Tensor] = None, session_id: Optional[str] = None) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
2523
+ """
2524
+ Forward hidden states to a peer for processing.
2525
+
2526
+ SECURITY: Calculates and validates SHA256 checksums to detect tampering.
2527
+ """
2528
+ from protos import neuroshard_pb2
2529
+ from protos import neuroshard_pb2_grpc
2530
+ from neuroshard.core.network.connection_pool import get_channel
2531
+ import numpy as np
2532
+ import hashlib
2533
+
2534
+ try:
2535
+ parsed = urlparse(peer_url)
2536
+ ip = parsed.hostname
2537
+ # gRPC port convention
2538
+ port = (parsed.port or 80) + 1000
2539
+
2540
+ channel = get_channel(f"{ip}:{port}")
2541
+ stub = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)
2542
+
2543
+ # Serialize hidden states
2544
+ hidden_bytes = hidden.detach().cpu().numpy().tobytes()
2545
+ hidden_shape = list(hidden.shape)
2546
+
2547
+ # CHECKSUM: Calculate SHA256 hash for integrity verification
2548
+ checksum = hashlib.sha256(hidden_bytes).hexdigest()
2549
+ logger.debug(f"[SECURITY] Sending layer {layer_id} with checksum: {checksum[:16]}...")
2550
+
2551
+ # Serialize labels if present
2552
+ labels_bytes = b""
2553
+ if labels is not None:
2554
+ labels_bytes = labels.cpu().numpy().tobytes()
2555
+
2556
+ req_session_id = session_id or f"train_{time.time()}"
2557
+
2558
+ # Get my URL for backward routing
2559
+ my_url = ""
2560
+ if self.p2p_manager:
2561
+ my_url = self.p2p_manager.my_url
2562
+
2563
+ req = neuroshard_pb2.PipelineForwardRequest(
2564
+ session_id=req_session_id,
2565
+ request_id=f"req_{time.time()}",
2566
+ hidden_states=hidden_bytes,
2567
+ hidden_shape=hidden_shape,
2568
+ target_shard=layer_id,
2569
+ use_cache=False,
2570
+ training_labels=labels_bytes,
2571
+ sender_url=my_url
2572
+ )
2573
+
2574
+ # No backward bookkeeping is needed on the sending side: the peer will call
+ # PipelineBackward on us later, routed via the sender_url included above.
2580
+
2581
+ resp = stub.PipelineForward(req, timeout=30.0)
2582
+
2583
+ if not resp.success:
2584
+ raise RuntimeError(f"Peer error: {resp.error_message}")
2585
+
2586
+ # Deserialize result
2587
+ if resp.is_final:
2588
+ # It's logits
2589
+ result_bytes = resp.logits
2590
+ result = torch.from_numpy(
2591
+ np.frombuffer(result_bytes, dtype=np.float32)
2592
+ ).reshape(list(resp.logits_shape))
2593
+ else:
2594
+ # It's hidden states (recursive/chained)
2595
+ result_bytes = resp.hidden_states
2596
+ result = torch.from_numpy(
2597
+ np.frombuffer(result_bytes, dtype=np.float32)
2598
+ ).reshape(list(resp.hidden_shape))
2599
+
2600
+ # CHECKSUM VALIDATION: Verify integrity of received data
2601
+ received_checksum = hashlib.sha256(result_bytes).hexdigest()
2602
+ logger.debug(f"[SECURITY] Received layer {layer_id} result with checksum: {received_checksum[:16]}...")
2603
+
2604
+ # AUDIT TRAIL: Store checksum in PipelineSession for tamper detection
2605
+ ledger = getattr(self, 'ledger', None)  # ledger may not be attached to this node
+ if session_id and ledger and hasattr(ledger, 'inference_market'):
+ market = ledger.inference_market
2607
+ if market and hasattr(market, 'active_sessions'):
2608
+ for sess_id, session in market.active_sessions.items():
2609
+ if sess_id == session_id or session.request_id in session_id:
2610
+ session.activations_hashes.append(received_checksum)
2611
+ logger.debug(f"[AUDIT] Stored checksum for layer {layer_id} in session")
2612
+ break
2613
+
2614
+ return result.to(self.device), None
2615
+
2616
+ except Exception as e:
2617
+ logger.error(f"Failed to forward to peer {peer_url}: {e}")
2618
+ return hidden, None
2619
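# Sketch (assumption): how a receiver could verify the digest computed above. The
# request message shown here does not carry the checksum field, so this helper is
# purely illustrative of the intended tamper check:
#
#   def _verify_payload(payload: bytes, expected_hex: str) -> bool:
#       return hashlib.sha256(payload).hexdigest() == expected_hex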
+
2620
+ def _backward_to_peer(self, peer_url: str, grad_output: torch.Tensor, layer_id: int, session_id: str):
2621
+ """Send gradients back to the previous peer."""
2622
+ from protos import neuroshard_pb2
2623
+ from protos import neuroshard_pb2_grpc
2624
+ from neuroshard.core.network.connection_pool import get_channel
2625
+
2626
+ try:
2627
+ parsed = urlparse(peer_url)
2628
+ ip = parsed.hostname
2629
+ port = (parsed.port or 80) + 1000
2630
+
2631
+ channel = get_channel(f"{ip}:{port}")
2632
+ stub = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)
2633
+
2634
+ grad_bytes = grad_output.detach().cpu().numpy().tobytes()
2635
+ grad_shape = list(grad_output.shape)
2636
+
2637
+ req = neuroshard_pb2.PipelineBackwardRequest(
2638
+ session_id=session_id,
2639
+ request_id=f"bw_{time.time()}",
2640
+ grad_output=grad_bytes,
2641
+ grad_shape=grad_shape,
2642
+ target_shard=layer_id
2643
+ )
2644
+
2645
+ stub.PipelineBackward(req, timeout=10.0)
2646
+
2647
+ except Exception as e:
2648
+ logger.error(f"Failed to backward to peer {peer_url}: {e}")
2649
+
2650
+ def _request_logits(self, hidden: torch.Tensor) -> torch.Tensor:
2651
+ """Request logits from the node that holds LM head."""
2652
+ # Find Last Layer holder (Validator)
2653
+ if not self.layer_pool:
2654
+ return hidden
2655
+
2656
+ capacity = self.layer_pool.get_network_capacity()
2657
+ last_layer = max(0, capacity.assigned_layers - 1)
2658
+
2659
+ peer_url = None
2660
+
2661
+ # 1. Check layer pool assignments
2662
+ assignments = self.layer_pool.get_layer_holders(last_layer)
2663
+ if assignments:
2664
+ import random
2665
+ peer_url = random.choice(assignments).grpc_addr
2666
+
2667
+ # 2. Fallback to P2P manager
2668
+ if not peer_url and self.p2p_manager:
2669
+ peer_url = self.p2p_manager.get_next_hop(last_layer)
2670
+
2671
+ if not peer_url:
2672
+ raise RuntimeError(f"No Validator (Layer {last_layer}) found in network")
2673
+
2674
+ # Forward hidden states to peer targeting Last Layer
2675
+ # The receiver will compute logits and return is_final=True
2676
+ logits, _ = self._forward_to_peer(peer_url, hidden, last_layer)
+ return logits
2677
+
2678
+ def generate(
2679
+ self,
2680
+ prompt: str,
2681
+ max_new_tokens: int = 50,
2682
+ temperature: float = 1.0,
2683
+ ) -> str:
2684
+ """Generate text from prompt."""
2685
+ try:
2686
+ if not self.tokenizer:
2687
+ raise RuntimeError("Tokenizer not initialized")
2688
+
2689
+ input_ids = torch.tensor([self.tokenizer.encode(prompt)], dtype=torch.long)
2690
+ logger.debug(f"[GENERATE] Encoded prompt: {input_ids.shape} tokens")
2691
+
2692
+ # Move to model's device (handles CPU, CUDA, MPS)
2693
+ generated = input_ids.clone().to(self.device)
2694
+
2695
+ # Get current vocabulary size from tokenizer
2696
+ # Only tokens 0 to current_vocab_size-1 are valid (have learned representations)
2697
+ # This is NOT a workaround - it's how BPE tokenizers work (vocab grows over time)
2698
+ valid_vocab_size = self.tokenizer.current_vocab_size
2699
+
2700
+ for step in range(max_new_tokens):
2701
+ logits = self.forward(generated)
2702
+ next_logits = logits[:, -1, :] / temperature
2703
+
2704
+ # Constrain to valid vocabulary (standard BPE tokenizer behavior)
2705
+ # Tokens beyond current_vocab_size don't exist in the tokenizer yet
2706
+ if valid_vocab_size < next_logits.size(-1):
2707
+ next_logits[:, valid_vocab_size:] = float('-inf')
2708
+
2709
+ probs = torch.softmax(next_logits, dim=-1)
2710
+ next_token = torch.multinomial(probs, num_samples=1)
2711
+ generated = torch.cat([generated, next_token], dim=-1)
2712
+
2713
+ if next_token.item() == 2: # EOS
2714
+ logger.debug(f"[GENERATE] EOS at step {step+1}")
2715
+ break
2716
+
2717
+ prompt_tokens = input_ids.size(1)
2718
+ new_tokens = generated[0, prompt_tokens:].tolist()
2719
+ result = self.tokenizer.decode(new_tokens)
2720
+ logger.debug(f"[GENERATE] Generated {len(new_tokens)} tokens: '{result[:100]}...'")
2721
+
2722
+ return result
2723
+
2724
+ except Exception as e:
2725
+ logger.error(f"[GENERATE] Error: {e}")
2726
+ import traceback
2727
+ logger.error(traceback.format_exc())
2728
+ raise
2729
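# Illustrative usage (values are examples):
#   text = node.generate("The decentralized network", max_new_tokens=32, temperature=0.8)
#   # Runs locally when this node holds every layer, otherwise routes through the pipeline.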
+
2730
+ # ==================== TRAINING ====================
2731
+
2732
+ def contribute_training_data(self, text: str, apply_dp: bool = True) -> int:
2733
+ """
2734
+ Contribute training data.
2735
+
2736
+ Returns the number of tokens added.
2737
+ """
2738
+ if not self.data_manager:
2739
+ return 0
2740
+
2741
+ # Get token count before
2742
+ stats_before = self.data_manager.get_stats()
2743
+ tokens_before = stats_before.get("total_tokens", 0)
2744
+
2745
+ self.data_manager.add_text(text, apply_dp=apply_dp)
2746
+
2747
+ # Get token count after
2748
+ stats_after = self.data_manager.get_stats()
2749
+ tokens_after = stats_after.get("total_tokens", 0)
2750
+
2751
+ tokens_added = tokens_after - tokens_before
2752
+ logger.info(f"Added {tokens_added} tokens to training buffer")
2753
+
2754
+ return tokens_added
2755
+
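+ # Usage sketch (illustrative; the text is an arbitrary example and `node` is a
+ # node instance with a data_manager attached):
+ #
+ #     added = node.contribute_training_data("some raw training text", apply_dp=True)
+ #     logger.info(f"contributed {added} tokens")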
2756
+ def _get_training_batch(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
2757
+ """
2758
+ Get a training batch from the Genesis data loader.
2759
+
2760
+ Returns:
2761
+ Tuple of (input_ids, labels) or None if data not available.
2762
+
2763
+ Note: Only Drivers (with embedding) load training data.
2764
+ Workers wait for activations via the pipeline; they don't need data directly.
2765
+ """
2766
+ if not self.enable_training:
2767
+ return None
2768
+
2769
+ # WORKERS DON'T LOAD DATA - they receive activations via pipeline
2770
+ # Only Drivers (with embedding) need to load training data
2771
+ if not self.model.has_embedding:
2772
+ return None # Worker - skip training data loading
2773
+
2774
+ # Initialize genesis loader if needed
2775
+ if not hasattr(self, 'genesis_loader') or self.genesis_loader is None:
2776
+ try:
2777
+ from neuroshard.core.training.distributed import GenesisDataLoader
2778
+ from neuroshard.core.model.tokenizer import get_neuro_tokenizer
2779
+ logger.info("[GENESIS] Initializing data loader...")
2780
+ self.genesis_loader = GenesisDataLoader(
2781
+ self.node_id,
2782
+ get_neuro_tokenizer(),
2783
+ max_storage_mb=self.max_storage_mb
2784
+ )
2785
+ logger.info(f"[GENESIS] Data loader ready: {self.genesis_loader.total_shards} shards available")
2786
+
2787
+ # Connect Swarm to Loader
2788
+ if hasattr(self, 'swarm') and self.swarm:
2789
+ self.genesis_loader.set_swarm(self.swarm)
2790
+ except Exception as e:
2791
+ logger.warning(f"[GENESIS] Failed to initialize loader: {e}")
2792
+ return None
2793
+
2794
+ # Check if data is ready
2795
+ if not self.genesis_loader.is_data_ready():
2796
+ return None
2797
+
2798
+ # Get batch
2799
+ batch_size = getattr(self, '_training_batch_size', 2)
2800
+ try:
2801
+ input_ids, labels = self.genesis_loader.get_batch(batch_size=batch_size)
2802
+ return input_ids, labels
2803
+ except Exception as e:
2804
+ logger.warning(f"[GENESIS] Failed to get batch: {e}")
2805
+ return None
2806
+
2807
+ def train_step(self) -> Optional[float]:
2808
+ """
2809
+ Perform a training step on my layers.
2810
+
2811
+ OPTIMIZED FOR SINGLE-NODE: When we have embedding + all layers + LM head,
2812
+ we skip distributed overhead and train locally.
2813
+
2814
+ NON-BLOCKING DATA: Uses prefetched data when available, raises RuntimeError
2815
+ if data not ready (caller should retry later).
2816
+ """
2817
+ if not self.enable_training:
2818
+ return None
2819
+
2820
+ # RUNTIME MEMORY CHECK: Skip training if system memory is critically high
2821
+ # This prevents OOM crashes and keeps the system responsive
2822
+ try:
2823
+ import psutil
2824
+ mem = psutil.virtual_memory()
2825
+ # Skip if less than 15% of system RAM is free (critical threshold)
2826
+ if mem.percent > 85:
2827
+ logger.warning(f"[NODE] System memory at {mem.percent:.0f}%, skipping training step")
2828
+ # Also try to free some memory
2829
+ import gc
2830
+ gc.collect()
2831
+ if self.device == "cuda":
2832
+ torch.cuda.empty_cache()
2833
+ elif self.device == "mps":
2834
+ torch.mps.empty_cache()
2835
+ return None
2836
+ except Exception:
2837
+ pass # If psutil fails, continue anyway
2838
+
2839
+ try:
2840
+ # SINGLE-NODE OPTIMIZATION: Check if we're a full node (Driver + Worker + Validator)
2841
+ # DYNAMIC CHECK: Use layer_pool to get current lm_head_holder
2842
+ # This handles the case where a new Validator joined and took over the LM head
2843
+ am_current_validator = self.model.has_lm_head
2844
+ if hasattr(self, 'layer_pool') and self.layer_pool:
2845
+ am_current_validator = (self.layer_pool.lm_head_holder == self.node_id)
2846
+
2847
+ is_full_node = self.model.has_embedding and am_current_validator
2848
+
2849
+ if self.model.has_embedding:
2850
+ # I am a Driver (Layer 0)
2851
+ # Use Genesis Data Loader
2852
+ if not hasattr(self, 'genesis_loader') or self.genesis_loader is None:
2853
+ try:
2854
+ from neuroshard.core.training.distributed import GenesisDataLoader
2855
+ from neuroshard.core.model.tokenizer import get_neuro_tokenizer
2856
+ logger.info("[GENESIS] Initializing data loader...")
2857
+ self.genesis_loader = GenesisDataLoader(
2858
+ self.node_id,
2859
+ get_neuro_tokenizer(),
2860
+ max_storage_mb=self.max_storage_mb
2861
+ )
2862
+ logger.info(f"[GENESIS] Data loader ready: {self.genesis_loader.total_shards} shards available")
2863
+
2864
+ # Connect Swarm to Loader
2865
+ if hasattr(self, 'swarm') and self.swarm:
2866
+ self.genesis_loader.set_swarm(self.swarm)
2867
+ except Exception as e:
2868
+ import traceback
2869
+ logger.error(f"[GENESIS] ERROR: {type(e).__name__}: {e}")
2870
+ logger.error(f"[GENESIS] {traceback.format_exc()}")
2871
+ # Reset the loader; the raised error lets the caller back off before retrying
2872
+ self.genesis_loader = None
2873
+ raise RuntimeError(f"Genesis loader init failed: {e}")
2874
+
2875
+ # Check if data is ready (non-blocking)
2876
+ if not self.genesis_loader.is_data_ready():
2877
+ # Data not ready - don't block, let caller retry
2878
+ raise RuntimeError("Data not ready - shard still loading")
2879
+
2880
+ # Get batch from Genesis Shard using memory-aware batch size
2881
+ batch_size = getattr(self, '_training_batch_size', 2)
2882
+ try:
2883
+ input_ids, labels = self.genesis_loader.get_batch(batch_size=batch_size)
2884
+ input_ids = input_ids.to(self.device)
2885
+ labels = labels.to(self.device)
2886
+ except RuntimeError as e:
2887
+ # Data not ready - propagate up
2888
+ raise
2889
+ except Exception as e:
2890
+ logger.warning(f"[GENESIS] Failed to get batch: {type(e).__name__}: {e}")
2891
+ import traceback
2892
+ logger.warning(traceback.format_exc())
2893
+ return None
2894
+
2895
+ # SINGLE-NODE OPTIMIZED PATH: Skip distributed overhead
2896
+ if is_full_node:
2897
+ return self._train_step_local(input_ids, labels)
2898
+
2899
+ # DISTRIBUTED PATH: Forward to next peer
2900
+ # Forward pass with optional gradient checkpointing
2901
+ # Note: time.sleep(0) yields GIL to keep HTTP server responsive
2902
+ embeddings = self.model.embed(input_ids)
2903
+ embeddings.requires_grad_(True)
2904
+ embeddings.retain_grad()
2905
+ time.sleep(0) # Yield GIL
2906
+
2907
+ # Use gradient checkpointing if enabled (trades CPU for memory)
2908
+ if getattr(self, '_use_gradient_checkpointing', False):
2909
+ output = torch.utils.checkpoint.checkpoint(
2910
+ self.model.forward_my_layers,
2911
+ embeddings,
2912
+ use_reentrant=False
2913
+ )
2914
+ else:
2915
+ output = self.model.forward_my_layers(embeddings)
2916
+ time.sleep(0) # Yield GIL after forward pass
2917
+
2918
+ # Distributed: Send to next peer
2919
+ my_last_layer = max(self.my_layer_ids) if self.my_layer_ids else 0
2920
+ next_layer = my_last_layer + 1
2921
+
2922
+ if self.p2p_manager:
2923
+ next_hop = self.p2p_manager.get_next_hop(next_layer)
2924
+ if next_hop:
2925
+ session_id = f"train_{self.node_id}_{time.time()}"
2926
+
2927
+ # Save context for backward
2928
+ self.training_context[session_id] = {
2929
+ "input": embeddings,
2930
+ "output": output,
2931
+ "sender_url": None, # We are the start
2932
+ "timestamp": time.time()
2933
+ }
2934
+
2935
+ result, _ = self._forward_to_peer(
2936
+ next_hop,
2937
+ output,
2938
+ next_layer,
2939
+ labels=labels,
2940
+ session_id=session_id
2941
+ )
2942
+
2943
+ # Check if forward succeeded (result should be different from output if it was processed)
2944
+ # If the peer rejected or failed, result will be the original output (unchanged)
2945
+ forward_succeeded = result is not output
2946
+
2947
+ if not forward_succeeded:
2948
+ # Pipeline forward failed - peer rejected or error
2949
+ # Clean up the training context
2950
+ if session_id in self.training_context:
2951
+ del self.training_context[session_id]
2952
+ logger.warning(f"[DISTRIBUTED] Pipeline forward failed - skipping training step")
2953
+ return None
2954
+
2955
+ # We don't get loss immediately in distributed pipeline
2956
+ # It comes back later via backward pass or status update
2957
+ # For now, return None (not inf!)
2958
+ return None
2959
+
2960
+ return None
2961
+
2962
+ else:
2963
+ # I am a Worker/Validator
2964
+ # I wait for activations from peers via gRPC (forward_pipeline)
2965
+ # So this method does nothing actively
2966
+ return None
2967
+
2968
+ except RuntimeError as e:
2969
+ error_msg = str(e)
2970
+ if "not ready" in error_msg.lower():
2971
+ # Data not ready - propagate to caller
2972
+ raise
2973
+ elif "out of memory" in error_msg.lower() or "MPS" in error_msg:
2974
+ logger.warning(f"Training step OOM - reducing batch size and clearing cache")
2975
+ # Clear GPU cache
2976
+ import gc
2977
+ gc.collect()
2978
+ if self.device == "mps":
2979
+ torch.mps.empty_cache()
2980
+ elif self.device == "cuda":
2981
+ torch.cuda.empty_cache()
2982
+
2983
+ # Reduce batch size for next attempt
2984
+ current_batch = getattr(self, '_training_batch_size', 8)
2985
+ if current_batch > 1:
2986
+ self._training_batch_size = max(1, current_batch // 2)
2987
+ logger.info(f"Reduced batch size to {self._training_batch_size}")
2988
+ else:
2989
+ # Already at minimum batch size - nothing left to reduce automatically, so just warn the user
2990
+ if self.device != "cpu":
2991
+ logger.warning(f"Batch size already at minimum. Consider using --memory flag to limit layers.")
2992
+ else:
2993
+ logger.error(f"Training step failed: {e}")
2994
+ return None
2995
+ except Exception as e:
2996
+ logger.error(f"Training step failed: {e}")
2997
+ return None
2998
+
2999
+ def _train_step_local(self, input_ids: torch.Tensor, labels: torch.Tensor) -> float:
3000
+ """
3001
+ OPTIMIZED single-node training step.
3002
+
3003
+ When we have ALL components (embedding + layers + LM head), we can
3004
+ train entirely locally without any network overhead.
3005
+ """
3006
+ # DIAGNOSTIC: Verify device placement periodically
3007
+ if self.total_training_rounds % 100 == 0:
3008
+ try:
3009
+ emb_device = next(self.model.embedding.parameters()).device if self.model.embedding else 'N/A'
3010
+ layer_device = next(next(iter(self.model.my_layers.values())).parameters()).device if self.model.my_layers else 'N/A'
3011
+ logger.info(f"[TRAIN] Device check: input={input_ids.device}, embedding={emb_device}, layer0={layer_device}")
3012
+ except Exception as e:
3013
+ logger.warning(f"[TRAIN] Device check failed: {e}")
3014
+
3015
+ # Forward pass with optional gradient checkpointing
3016
+ embeddings = self.model.embed(input_ids)
3017
+
3018
+ # Use gradient checkpointing if enabled (trades CPU for memory)
3019
+ if getattr(self, '_use_gradient_checkpointing', False):
3020
+ output = torch.utils.checkpoint.checkpoint(
3021
+ self.model.forward_my_layers,
3022
+ embeddings,
3023
+ use_reentrant=False
3024
+ )
3025
+ else:
3026
+ output = self.model.forward_my_layers(embeddings)
3027
+
3028
+ # Compute logits and loss
3029
+ logits = self.model.compute_logits(output)
3030
+
3031
+ loss = torch.nn.functional.cross_entropy(
3032
+ logits.view(-1, logits.size(-1)),
3033
+ labels.view(-1),
3034
+ ignore_index=-100
3035
+ )
3036
+
3037
+ # Use training lock to prevent conflict with pipeline training
3038
+ with self._training_lock:
3039
+ # Backward pass
3040
+ self.optimizer.zero_grad()
3041
+ loss.backward()
3042
+
3043
+ # Gradient clipping and optimizer step
3044
+ torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
3045
+ self.optimizer.step()
3046
+
3047
+ # Update stats
3048
+ self.total_training_rounds += 1
3049
+ loss_val = loss.item()
3050
+ self.current_loss = loss_val
3051
+
3052
+ # PERIODIC CHECKPOINT: Save every 100 steps
3053
+ # The save may block briefly during the snapshot phase; heavy disk I/O runs in a background thread when memory allows
3054
+ if self.total_training_rounds % 100 == 0:
3055
+ self._save_checkpoint()
3056
+
3057
+ return loss_val
3058
+
3059
+ # ==================== STATS & PONW ====================
3060
+
3061
+ def get_stats(self) -> Dict[str, Any]:
3062
+ """Get node statistics."""
3063
+ # Safety check for shutdown race condition
3064
+ model = getattr(self, 'model', None)
3065
+ layer_pool = getattr(self, 'layer_pool', None)
3066
+
3067
+ contribution = model.get_my_contribution() if model else {}
3068
+ capacity = layer_pool.get_network_capacity() if layer_pool else None
3069
+
3070
+ # Calculate reward multiplier
3071
+ my_layer_ids = getattr(self, 'my_layer_ids', [])
3072
+ network_layers = capacity.assigned_layers if capacity else len(my_layer_ids)
3073
+ reward_multiplier = calculate_reward_multiplier(
3074
+ num_layers_held=len(my_layer_ids),
3075
+ total_network_layers=network_layers or 1,
3076
+ has_embedding=model.has_embedding if model else False,
3077
+ has_lm_head=model.has_lm_head if model else False,
3078
+ )
3079
+
3080
+ # Estimate network params (rough: ~10M params per layer)
3081
+ network_params = network_layers * 10_000_000 if network_layers else 0
3082
+
3083
+ # Get data buffer size
3084
+ data_buffer_size = 0
3085
+ if self.data_manager:
3086
+ data_stats = self.data_manager.get_stats()
3087
+ data_buffer_size = data_stats.get("buffer_size", 0)
3088
+
3089
+ # Get shard stats (if we have a genesis loader)
3090
+ shard_stats = {}
3091
+ if hasattr(self, 'genesis_loader') and self.genesis_loader:
3092
+ shard_stats = self.genesis_loader.get_stats()
3093
+
3094
+ # Multi-node identity info
3095
+ instance_id = getattr(self, 'instance_id', None)
3096
+ wallet_id = getattr(self, 'wallet_id', None)
3097
+
3098
+ return {
3099
+ "node_id": self.node_id[:16] + "...",
3100
+ "instance_id": instance_id, # Unique per machine+port
3101
+ "wallet_id": wallet_id, # Same across instances with same token
3102
+ "available_memory_mb": self.available_memory_mb,
3103
+ "my_layers": self.my_layer_ids,
3104
+ "my_params": contribution.get("my_params", 0),
3105
+ "has_embedding": contribution.get("has_embedding", False),
3106
+ "has_lm_head": contribution.get("has_lm_head", False),
3107
+ "contribution_ratio": contribution.get("contribution_ratio", 0),
3108
+ "reward_multiplier": reward_multiplier,
3109
+ "network_layers": network_layers,
3110
+ "network_params": network_params,
3111
+ "network_nodes": capacity.total_nodes if capacity else 1,
3112
+ "total_tokens_processed": self.total_tokens_processed,
3113
+ "total_training_rounds": self.total_training_rounds,
3114
+ "current_loss": self.current_loss,
3115
+ "inference_count": self.inference_count,
3116
+ "data_buffer_size": data_buffer_size,
3117
+ "shard_stats": shard_stats,
3118
+ }
3119
+
3120
+ def get_ponw_proof(self) -> Dict[str, Any]:
3121
+ """
3122
+ Generate Proof of Neural Work.
3123
+
3124
+ This proof demonstrates verifiable neural network computation
3125
+ and is used for NEURO token rewards.
3126
+ """
3127
+ contribution = self.model.get_my_contribution() if self.model else {}
3128
+ capacity = self.layer_pool.get_network_capacity() if self.layer_pool else None
3129
+
3130
+ # Calculate reward multiplier
3131
+ multiplier = calculate_reward_multiplier(
3132
+ num_layers_held=len(self.my_layer_ids),
3133
+ total_network_layers=capacity.assigned_layers if capacity else 1,
3134
+ has_embedding=self.model.has_embedding if self.model else False,
3135
+ has_lm_head=self.model.has_lm_head if self.model else False,
3136
+ )
3137
+
3138
+ timestamp = time.time()
3139
+
3140
+ # Determine role
3141
+ role = "Worker"
3142
+ if self.model and self.model.has_embedding:
3143
+ role = "Driver"
3144
+ elif self.model and self.model.has_lm_head:
3145
+ role = "Validator"
3146
+
3147
+ proof_data = {
3148
+ "node_id": self.node_id,
3149
+ "timestamp": timestamp,
3150
+ "tokens_processed": self.total_tokens_processed,
3151
+ "training_rounds": self.total_training_rounds,
3152
+ "training_contributions": self.training_contribution_count,
3153
+ "inference_count": self.inference_count,
3154
+ "layers_held": len(self.my_layer_ids),
3155
+ "layer_ids": self.my_layer_ids,
3156
+ "has_embedding": self.model.has_embedding if self.model else False,
3157
+ "has_lm_head": self.model.has_lm_head if self.model else False,
3158
+ "role": role,
3159
+ "reward_multiplier": multiplier,
3160
+ "available_memory_mb": self.available_memory_mb,
3161
+ }
3162
+
3163
+ # Add model hash for verification
3164
+ # Use architecture-based hash (consistent with SwarmEnabledDynamicNode._get_model_hash)
3165
+ if self.model:
3166
+ hasher = hashlib.sha256()
3167
+ arch_str = f"{self.model.hidden_dim}:{len(self.my_layer_ids)}:{getattr(self.model, 'num_heads', 0)}"
3168
+ hasher.update(arch_str.encode())
3169
+ for name, param in sorted(self.model.named_parameters()):
3170
+ hasher.update(f"{name}:{list(param.shape)}".encode())
3171
+ proof_data["model_hash"] = hasher.hexdigest()[:16]
3172
+
3173
+ # Sign the proof
3174
+ proof_string = f"{self.node_id}:{timestamp}:{self.total_tokens_processed}:{len(self.my_layer_ids)}:{self.total_training_rounds}"
3175
+ if self.node_token:
3176
+ # Use HMAC for proper signing
3177
+ import hmac
3178
+ signature = hmac.new(
3179
+ self.node_token.encode(),
3180
+ proof_string.encode(),
3181
+ hashlib.sha256
3182
+ ).hexdigest()
3183
+ else:
3184
+ signature = hashlib.sha256(proof_string.encode()).hexdigest()
3185
+
3186
+ proof_data["signature"] = signature
3187
+
3188
+ return proof_data
3189
+
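+ # Verification sketch for the HMAC-signed proof above. This is an assumption about
+ # how a verifier that already knows the node_token could check the signature; it is
+ # not an API defined elsewhere in this module.
+ #
+ #     import hashlib, hmac
+ #
+ #     def verify_ponw_signature(proof: dict, node_token: str) -> bool:
+ #         # Rebuild the exact string that was signed in get_ponw_proof()
+ #         proof_string = (
+ #             f"{proof['node_id']}:{proof['timestamp']}:{proof['tokens_processed']}:"
+ #             f"{proof['layers_held']}:{proof['training_rounds']}"
+ #         )
+ #         expected = hmac.new(node_token.encode(), proof_string.encode(), hashlib.sha256).hexdigest()
+ #         return hmac.compare_digest(expected, proof["signature"])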
3190
+ def _prefetch_vocab_capacity(self):
3191
+ """
3192
+ Pre-fetch tokenizer vocab size to know how much memory embeddings will need.
3193
+
3194
+ This MUST be called before register_node() to ensure accurate layer assignment.
3195
+ Without this, we'd assign layers assuming 32K vocab, then OOM when vocab expands to 288K+.
3196
+ """
3197
+ import requests
3198
+ import os
3199
+
3200
+ GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
3201
+ cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")
3202
+ tokenizer_cache_path = os.path.join(cache_dir, "tokenizer.json")
3203
+
3204
+ vocab_size = INITIAL_VOCAB_SIZE # Default fallback
3205
+
3206
+ try:
3207
+ # Try to fetch vocab size from CDN
3208
+ tokenizer_url = f"{GENESIS_CDN_URL}/tokenizer.json"
3209
+ resp = requests.get(tokenizer_url, timeout=10)
3210
+
3211
+ if resp.status_code == 200:
3212
+ data = resp.json()
3213
+ vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)
3214
+
3215
+ # Cache locally for faster startup next time
3216
+ os.makedirs(cache_dir, exist_ok=True)
3217
+ with open(tokenizer_cache_path, 'w') as f:
3218
+ f.write(resp.text)
3219
+
3220
+ logger.info(f"[VOCAB] Pre-fetched tokenizer: {vocab_size:,} tokens (for memory calculation)")
3221
+ else:
3222
+ # Try cached version
3223
+ if os.path.exists(tokenizer_cache_path):
3224
+ import json
3225
+ with open(tokenizer_cache_path, 'r') as f:
3226
+ data = json.load(f)
3227
+ vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)
3228
+ logger.info(f"[VOCAB] Using cached tokenizer: {vocab_size:,} tokens")
3229
+ except Exception as e:
3230
+ # Try cached version as fallback
3231
+ try:
3232
+ if os.path.exists(tokenizer_cache_path):
3233
+ import json
3234
+ with open(tokenizer_cache_path, 'r') as f:
3235
+ data = json.load(f)
3236
+ vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)
3237
+ logger.info(f"[VOCAB] Using cached tokenizer: {vocab_size:,} tokens (CDN unavailable)")
3238
+ except Exception:
3239
+ pass
3240
+ logger.debug(f"[VOCAB] Could not prefetch vocab size: {e}, using default {INITIAL_VOCAB_SIZE}")
3241
+
3242
+ # Round up to next chunk boundary (no headroom - recalculate if vocab grows)
3243
+ # Previously used 64K headroom but this wastes ~1GB memory on limited devices
3244
+ vocab_capacity = ((vocab_size + VOCAB_GROWTH_CHUNK - 1) // VOCAB_GROWTH_CHUNK) * VOCAB_GROWTH_CHUNK
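+ # Worked example (assuming VOCAB_GROWTH_CHUNK == 32768, which is an assumption,
+ # not a value read from this file): vocab_size = 290_000 rounds up to
+ # ((290_000 + 32_767) // 32_768) * 32_768 = 294_912.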
3245
+
3246
+ # Update layer pool's vocab_capacity for accurate layer assignment
3247
+ self.layer_pool.vocab_capacity = vocab_capacity
3248
+ logger.info(f"[VOCAB] Layer pool vocab_capacity set to {vocab_capacity:,} (current vocab: {vocab_size:,})")
3249
+
3250
+ def _reconcile_architecture(self):
3251
+ """
3252
+ Smart architecture reconciliation for rejoining the network.
3253
+
3254
+ Handles all scenarios:
3255
+ 1. Quick restart (same architecture) → Use checkpoint
3256
+ 2. Network upgraded (larger arch) → Start fresh with network arch
3257
+ 3. Network downgraded (smaller arch) → Start fresh with network arch
3258
+ 4. Solo bootstrap (no peers) → Use checkpoint or calculate
3259
+ 5. First time (no checkpoint) → Query network or calculate
3260
+
3261
+ Priority:
3262
+ 1. Network consensus (if peers available)
3263
+ 2. Saved checkpoint (if compatible)
3264
+ 3. Fresh calculation (fallback)
3265
+ """
3266
+ saved_arch = self._peek_checkpoint_architecture()
3267
+ network_arch = self._query_network_architecture()
3268
+
3269
+ # Log what we found
3270
+ if saved_arch:
3271
+ logger.info(f"Saved checkpoint: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H")
3272
+ else:
3273
+ logger.info(f"No saved checkpoint found")
3274
+
3275
+ if network_arch:
3276
+ logger.info(f"Network architecture: {network_arch.num_layers}L × {network_arch.hidden_dim}H")
3277
+ else:
3278
+ logger.info(f"No peers found (solo mode or bootstrap)")
3279
+
3280
+ # Decision matrix
3281
+ if network_arch and saved_arch:
3282
+ # Both exist - compare them
3283
+ if self._architectures_compatible(saved_arch, network_arch):
3284
+ # Perfect - checkpoint matches network
3285
+ logger.info(f"✅ Checkpoint compatible with network - will load checkpoint")
3286
+ self.layer_pool.current_architecture = network_arch
3287
+ self.layer_pool.current_num_layers = network_arch.num_layers
3288
+ else:
3289
+ # Mismatch - network takes priority
3290
+ logger.warning(f"⚠️ Architecture mismatch!")
3291
+ logger.warning(f" Checkpoint: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H")
3292
+ logger.warning(f" Network: {network_arch.num_layers}L × {network_arch.hidden_dim}H")
3293
+
3294
+ # Check if network arch fits in our memory
3295
+ network_memory = network_arch.estimate_memory_mb()
3296
+ if network_memory <= self.available_memory_mb:
3297
+ logger.warning(f" → Using NETWORK architecture (checkpoint will be incompatible)")
3298
+ logger.warning(f" → Your training progress will be preserved but weights reset")
3299
+ self.layer_pool.current_architecture = network_arch
3300
+ self.layer_pool.current_num_layers = network_arch.num_layers
3301
+ # Rename old checkpoint instead of deleting
3302
+ self._archive_incompatible_checkpoint()
3303
+ else:
3304
+ logger.error(f" → Network arch needs {network_memory}MB but you only have {self.available_memory_mb}MB!")
3305
+ logger.error(f" → This node cannot participate in current network")
3306
+ logger.error(f" → Falling back to solo mode with checkpoint architecture")
3307
+ self.layer_pool.current_architecture = saved_arch
3308
+ self.layer_pool.current_num_layers = saved_arch.num_layers
3309
+
3310
+ elif network_arch:
3311
+ # Network exists but no checkpoint - join the network
3312
+ network_memory = network_arch.estimate_memory_mb()
3313
+ if network_memory <= self.available_memory_mb:
3314
+ logger.info(f"✅ Joining network with architecture: {network_arch.num_layers}L × {network_arch.hidden_dim}H")
3315
+ self.layer_pool.current_architecture = network_arch
3316
+ self.layer_pool.current_num_layers = network_arch.num_layers
3317
+ else:
3318
+ logger.warning(f"⚠️ Network arch needs {network_memory}MB but you only have {self.available_memory_mb}MB")
3319
+ logger.warning(f" → Will calculate a smaller architecture (may train in isolation)")
3320
+ # Let register_node calculate appropriate architecture
3321
+
3322
+ elif saved_arch:
3323
+ # Checkpoint exists but no network peers (solo mode)
3324
+ # IMPORTANT: Check ACTUAL layers in checkpoint, not just architecture's num_layers
3325
+ actual_saved_layers = self._get_checkpoint_layer_count()
3326
+ if actual_saved_layers and actual_saved_layers > saved_arch.num_layers:
3327
+ # Model grew beyond base architecture - calculate memory for actual layers
3328
+ memory_per_layer = estimate_memory_per_layer(saved_arch)
3329
+ saved_memory = memory_per_layer * actual_saved_layers * 1.1 # 10% overhead
3330
+ logger.info(f"Checkpoint has {actual_saved_layers} layers (grew from base {saved_arch.num_layers})")
3331
+ else:
3332
+ saved_memory = saved_arch.estimate_memory_mb()
3333
+ actual_saved_layers = saved_arch.num_layers
3334
+
3335
+ if saved_memory <= self.available_memory_mb:
3336
+ logger.info(f"✅ Solo mode - using saved checkpoint: {actual_saved_layers}L × {saved_arch.hidden_dim}H")
3337
+ self.layer_pool.current_architecture = saved_arch
3338
+ self.layer_pool.current_num_layers = actual_saved_layers
3339
+ else:
3340
+ # Memory too small for saved checkpoint - calculate what we CAN fit
3341
+ # Use current vocab_capacity for accurate memory estimation
3342
+ vocab_cap = getattr(self.layer_pool, 'vocab_capacity', INITIAL_VOCAB_SIZE)
3343
+ max_layers = calculate_layer_assignment(
3344
+ self.available_memory_mb, saved_arch,
3345
+ safety_factor=0.6, vocab_capacity=vocab_cap,
3346
+ training_mode=True # Conservative for training
3347
+ )
3348
+ logger.warning(f"⚠️ Saved checkpoint has {actual_saved_layers} layers ({saved_memory:.0f}MB) "
3349
+ f"but you only have {self.available_memory_mb:.0f}MB")
3350
+ logger.warning(f" → Will use {max_layers} layers (reduced from checkpoint)")
3351
+ self.layer_pool.current_architecture = saved_arch
3352
+ self.layer_pool.current_num_layers = max_layers
3353
+
3354
+ else:
3355
+ # No checkpoint, no network - fresh start
3356
+ logger.info(f"Fresh start - architecture will be calculated from available memory")
3357
+
3358
+ def _query_network_architecture(self) -> Optional[ModelArchitecture]:
3359
+ """
3360
+ Query the network for the current architecture.
3361
+
3362
+ Tries multiple sources:
3363
+ 1. DHT lookup for architecture announcements
3364
+ 2. Tracker API for network stats
3365
+ 3. Direct peer query
3366
+
3367
+ Returns None if no peers available (solo mode).
3368
+ """
3369
+ import requests
3370
+
3371
+ # Method 1: Try tracker API first (fastest, most reliable)
3372
+ if self.tracker_url:
3373
+ try:
3374
+ # Query tracker for network architecture
3375
+ response = requests.get(
3376
+ f"{self.tracker_url}/network_architecture",
3377
+ timeout=5
3378
+ )
3379
+ if response.ok:
3380
+ data = response.json()
3381
+ if data.get("hidden_dim"):
3382
+ arch = ModelArchitecture(
3383
+ hidden_dim=data["hidden_dim"],
3384
+ intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
3385
+ num_layers=data.get("num_layers", 12),
3386
+ num_heads=data.get("num_heads", 12),
3387
+ num_kv_heads=data.get("num_kv_heads", 4),
3388
+ )
3389
+ logger.debug(f"Got network architecture from tracker: {arch.num_layers}L × {arch.hidden_dim}H")
3390
+ return arch
3391
+ except Exception as e:
3392
+ logger.debug(f"Tracker architecture query failed: {e}")
3393
+
3394
+ # Method 2: Query known peers directly
3395
+ if self.p2p_manager and self.p2p_manager.known_peers:
3396
+ for peer_url in list(self.p2p_manager.known_peers.keys())[:3]:
3397
+ try:
3398
+ response = requests.get(
3399
+ f"{peer_url}/api/node/architecture",
3400
+ timeout=3
3401
+ )
3402
+ if response.ok:
3403
+ data = response.json()
3404
+ if data.get("hidden_dim"):
3405
+ arch = ModelArchitecture(
3406
+ hidden_dim=data["hidden_dim"],
3407
+ intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
3408
+ num_layers=data.get("num_layers", 12),
3409
+ num_heads=data.get("num_heads", 12),
3410
+ num_kv_heads=data.get("num_kv_heads", 4),
3411
+ )
3412
+ logger.debug(f"Got network architecture from peer {peer_url}: {arch.num_layers}L × {arch.hidden_dim}H")
3413
+ return arch
3414
+ except Exception:
3415
+ continue
3416
+
3417
+ # Method 3: DHT lookup (if available)
3418
+ if self.p2p_manager and hasattr(self.p2p_manager, 'dht') and self.p2p_manager.dht:
3419
+ try:
3420
+ import hashlib
3421
+ key = int(hashlib.sha1("network_architecture".encode()).hexdigest(), 16)
3422
+ value = self.p2p_manager.dht.lookup_value(key)
3423
+ if value:
3424
+ import json
3425
+ data = json.loads(value)
3426
+ if isinstance(data, dict) and data.get("hidden_dim"):
3427
+ arch = ModelArchitecture(
3428
+ hidden_dim=data["hidden_dim"],
3429
+ intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
3430
+ num_layers=data.get("num_layers", 12),
3431
+ num_heads=data.get("num_heads", 12),
3432
+ num_kv_heads=data.get("num_kv_heads", 4),
3433
+ )
3434
+ logger.debug(f"Got network architecture from DHT: {arch.num_layers}L × {arch.hidden_dim}H")
3435
+ return arch
3436
+ except Exception as e:
3437
+ logger.debug(f"DHT architecture lookup failed: {e}")
3438
+
3439
+ return None
3440
+
3441
+ def _architectures_compatible(self, arch1: ModelArchitecture, arch2: ModelArchitecture) -> bool:
3442
+ """
3443
+ Check if two architectures are compatible for gradient exchange.
3444
+
3445
+ Compatible means: same hidden_dim, num_heads, num_kv_heads
3446
+ (num_layers can differ - nodes just hold different subsets)
3447
+ """
3448
+ return (
3449
+ arch1.hidden_dim == arch2.hidden_dim and
3450
+ arch1.num_heads == arch2.num_heads and
3451
+ arch1.num_kv_heads == arch2.num_kv_heads
3452
+ )
3453
+
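+ # Example (illustrative values): a 12-layer node and a 48-layer node with
+ # hidden_dim=768, num_heads=12, num_kv_heads=4 are compatible - they simply hold
+ # different layer subsets - while nodes whose hidden_dim or head counts differ are not.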
3454
+ def _archive_incompatible_checkpoint(self):
3455
+ """
3456
+ Archive an incompatible checkpoint instead of deleting it.
3457
+
3458
+ Storage-aware: Keeps only MAX_ARCHIVED_CHECKPOINTS and respects
3459
+ the user's storage budget.
3460
+ """
3461
+ MAX_ARCHIVED_CHECKPOINTS = 2 # Keep at most 2 old checkpoints
3462
+
3463
+ path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
3464
+
3465
+ if not path.exists():
3466
+ return
3467
+
3468
+ # First, clean up old archives to stay within limits
3469
+ self._cleanup_old_archives(MAX_ARCHIVED_CHECKPOINTS - 1) # Make room for new one
3470
+
3471
+ # Now archive the current checkpoint
3472
+ import time
3473
+ timestamp = int(time.time())
3474
+ archive_path = self.CHECKPOINT_DIR / f"archived_{self.wallet_id}_{timestamp}.pt"
3475
+
3476
+ try:
3477
+ path.rename(archive_path)
3478
+ logger.info(f"Archived incompatible checkpoint to: {archive_path.name}")
3479
+ except Exception as e:
3480
+ logger.warning(f"Could not archive checkpoint: {e}")
3481
+ # If archive fails, just delete it
3482
+ try:
3483
+ path.unlink()
3484
+ logger.info(f"Deleted incompatible checkpoint (archive failed)")
3485
+ except Exception:
3486
+ pass
3487
+
3488
+ def _cleanup_old_archives(self, max_keep: int = 2):
3489
+ """
3490
+ Clean up old archived checkpoints, keeping only the most recent ones.
3491
+
3492
+ Also enforces storage budget if archives are taking too much space.
3493
+ """
3494
+ # Find all archives for this wallet
3495
+ pattern = f"archived_{self.wallet_id}_*.pt"
3496
+ archives = sorted(
3497
+ self.CHECKPOINT_DIR.glob(pattern),
3498
+ key=lambda p: p.stat().st_mtime,
3499
+ reverse=True # Newest first
3500
+ )
3501
+
3502
+ # Calculate total archive size
3503
+ total_archive_mb = sum(p.stat().st_size / (1024 * 1024) for p in archives)
3504
+
3505
+ # Storage budget: archives should use at most 20% of max_storage
3506
+ archive_budget_mb = self.max_storage_mb * 0.2
3507
+
3508
+ # Delete archives that exceed count OR storage limits
3509
+ deleted_count = 0
3510
+ for i, archive in enumerate(archives):
3511
+ should_delete = False
3512
+
3513
+ # Too many archives
3514
+ if i >= max_keep:
3515
+ should_delete = True
3516
+
3517
+ # Over storage budget
3518
+ if total_archive_mb > archive_budget_mb:
3519
+ should_delete = True
3520
+
3521
+ if should_delete:
3522
+ try:
3523
+ archive_size_mb = archive.stat().st_size / (1024 * 1024)
3524
+ archive.unlink()
3525
+ total_archive_mb -= archive_size_mb
3526
+ deleted_count += 1
3527
+ logger.debug(f"Cleaned up old archive: {archive.name}")
3528
+ except Exception:
3529
+ pass
3530
+
3531
+ if deleted_count > 0:
3532
+ logger.info(f"Cleaned up {deleted_count} old archived checkpoint(s)")
3533
+
3534
+ def _get_checkpoint_layer_count(self) -> Optional[int]:
3535
+ """
3536
+ Get the actual number of layers saved in checkpoint.
3537
+
3538
+ This is important because the model may have GROWN beyond the base architecture.
3539
+ The architecture might say 11 layers, but 110 layers could be saved!
3540
+ """
3541
+ path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
3542
+ legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"
3543
+
3544
+ checkpoint_path = path if path.exists() else (legacy_path if legacy_path.exists() else None)
3545
+ if not checkpoint_path:
3546
+ return None
3547
+
3548
+ try:
3549
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
3550
+ layers = checkpoint.get("layers", {})
3551
+ if layers:
3552
+ return len(layers)
3553
+ # Fall back to layer_ids if present
3554
+ layer_ids = checkpoint.get("layer_ids", [])
3555
+ if layer_ids:
3556
+ return len(layer_ids)
3557
+ except Exception as e:
3558
+ logger.debug(f"Could not get checkpoint layer count: {e}")
3559
+
3560
+ return None
3561
+
3562
+ def _peek_checkpoint_architecture(self) -> Optional[ModelArchitecture]:
3563
+ """
3564
+ Peek at the checkpoint to read its saved architecture without applying any weights to the model.
3565
+
3566
+ This allows us to use the same architecture as the checkpoint,
3567
+ preventing architecture drift between restarts on the same machine.
3568
+ """
3569
+ # Use wallet_id for stable checkpoint path
3570
+ path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
3571
+
3572
+ # Also check legacy path
3573
+ legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"
3574
+
3575
+ checkpoint_path = None
3576
+ if path.exists():
3577
+ checkpoint_path = path
3578
+ elif legacy_path.exists():
3579
+ checkpoint_path = legacy_path
3580
+
3581
+ if not checkpoint_path:
3582
+ return None
3583
+
3584
+ try:
3585
+ # Load the full checkpoint on CPU just to read its architecture metadata
3586
+ checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
3587
+
3588
+ arch_dict = checkpoint.get("architecture")
3589
+ if arch_dict:
3590
+ return ModelArchitecture.from_dict(arch_dict)
3591
+ except Exception as e:
3592
+ logger.debug(f"Could not peek checkpoint architecture: {e}")
3593
+
3594
+ return None
3595
+
3596
+ def _load_embedding_with_vocab_expansion(self, embedding: nn.Embedding, state_dict: dict, name: str):
3597
+ """
3598
+ Load embedding weights, handling vocab size expansion gracefully.
3599
+
3600
+ When vocabulary grows (tokenizer learns new merges), the checkpoint has fewer
3601
+ tokens than the current model. This method:
3602
+ 1. Loads weights for existing tokens (preserves all training)
3603
+ 2. Keeps randomly initialized weights for new tokens
3604
+ """
3605
+ checkpoint_weight = state_dict.get("weight")
3606
+ if checkpoint_weight is None:
3607
+ logger.warning(f"[CHECKPOINT] No weight found in {name} state_dict")
3608
+ return
3609
+
3610
+ checkpoint_vocab_size = checkpoint_weight.shape[0]
3611
+ current_vocab_size = embedding.weight.shape[0]
3612
+
3613
+ if checkpoint_vocab_size == current_vocab_size:
3614
+ # Same size - normal load
3615
+ embedding.load_state_dict(state_dict)
3616
+ logger.info(f"[CHECKPOINT] Loaded {name}: {current_vocab_size} tokens")
3617
+ elif checkpoint_vocab_size < current_vocab_size:
3618
+ # Vocab expanded - partial load (PRESERVE TRAINING!)
3619
+ with torch.no_grad():
3620
+ embedding.weight[:checkpoint_vocab_size] = checkpoint_weight
3621
+ logger.info(f"[CHECKPOINT] Loaded {name} with vocab expansion: "
3622
+ f"{checkpoint_vocab_size} → {current_vocab_size} tokens "
3623
+ f"(preserved {checkpoint_vocab_size} trained embeddings)")
3624
+ else:
3625
+ # Vocab shrunk (unusual) - load what fits
3626
+ with torch.no_grad():
3627
+ embedding.weight[:] = checkpoint_weight[:current_vocab_size]
3628
+ logger.warning(f"[CHECKPOINT] Loaded {name} with vocab truncation: "
3629
+ f"{checkpoint_vocab_size} → {current_vocab_size} tokens")
3630
+
3631
+ def _load_lm_head_with_vocab_expansion(self, lm_head: nn.Linear, state_dict: dict, name: str):
3632
+ """
3633
+ Load LM head weights, handling vocab size expansion gracefully.
3634
+
3635
+ Similar to embedding expansion - preserves trained weights for existing tokens.
3636
+ """
3637
+ checkpoint_weight = state_dict.get("weight")
3638
+ checkpoint_bias = state_dict.get("bias")
3639
+
3640
+ if checkpoint_weight is None:
3641
+ logger.warning(f"[CHECKPOINT] No weight found in {name} state_dict")
3642
+ return
3643
+
3644
+ checkpoint_vocab_size = checkpoint_weight.shape[0]
3645
+ current_vocab_size = lm_head.weight.shape[0]
3646
+
3647
+ if checkpoint_vocab_size == current_vocab_size:
3648
+ # Same size - normal load
3649
+ lm_head.load_state_dict(state_dict)
3650
+ logger.info(f"[CHECKPOINT] Loaded {name}: {current_vocab_size} outputs")
3651
+ elif checkpoint_vocab_size < current_vocab_size:
3652
+ # Vocab expanded - partial load (PRESERVE TRAINING!)
3653
+ with torch.no_grad():
3654
+ lm_head.weight[:checkpoint_vocab_size] = checkpoint_weight
3655
+ if checkpoint_bias is not None and lm_head.bias is not None:
3656
+ lm_head.bias[:checkpoint_vocab_size] = checkpoint_bias
3657
+ logger.info(f"[CHECKPOINT] Loaded {name} with vocab expansion: "
3658
+ f"{checkpoint_vocab_size} → {current_vocab_size} outputs "
3659
+ f"(preserved {checkpoint_vocab_size} trained weights)")
3660
+ else:
3661
+ # Vocab shrunk (unusual) - load what fits
3662
+ with torch.no_grad():
3663
+ lm_head.weight[:] = checkpoint_weight[:current_vocab_size]
3664
+ if checkpoint_bias is not None and lm_head.bias is not None:
3665
+ lm_head.bias[:] = checkpoint_bias[:current_vocab_size]
3666
+ logger.warning(f"[CHECKPOINT] Loaded {name} with vocab truncation: "
3667
+ f"{checkpoint_vocab_size} → {current_vocab_size} outputs")
3668
+
3669
+ def _load_checkpoint(self):
3670
+ """Load checkpoint from disk if it exists (resume training)."""
3671
+ # Use wallet_id for stable checkpoint path (survives node_id changes)
3672
+ path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
3673
+
3674
+ # Also check legacy path (node_id-based) for migration
3675
+ legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"
3676
+ if not path.exists() and legacy_path.exists():
3677
+ logger.info(f"Migrating checkpoint from legacy path: {legacy_path.name} -> {path.name}")
3678
+ legacy_path.rename(path)
3679
+
3680
+ if not path.exists():
3681
+ logger.info(f"No checkpoint found at {path.name}, starting fresh")
3682
+ return False
3683
+
3684
+ logger.info(f"Loading checkpoint from: {path.name}")
3685
+ try:
3686
+ checkpoint = torch.load(path, map_location=self.device, weights_only=False)
3687
+
3688
+ # ARCHITECTURE COMPATIBILITY CHECK
3689
+ saved_arch_dict = checkpoint.get("architecture")
3690
+ if saved_arch_dict:
3691
+ saved_arch = ModelArchitecture.from_dict(saved_arch_dict)
3692
+ current_arch = self.model.architecture
3693
+
3694
+ # Check if architecture changed (includes num_heads for head_dim compatibility)
3695
+ if (saved_arch.hidden_dim != current_arch.hidden_dim or
3696
+ saved_arch.intermediate_dim != current_arch.intermediate_dim or
3697
+ saved_arch.num_heads != current_arch.num_heads or
3698
+ saved_arch.num_kv_heads != current_arch.num_kv_heads):
3699
+ logger.warning(f"Architecture mismatch! Checkpoint is incompatible.")
3700
+ logger.warning(f" Saved: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H, "
3701
+ f"heads={saved_arch.num_heads}/{saved_arch.num_kv_heads}")
3702
+ logger.warning(f" Current: {current_arch.num_layers}L × {current_arch.hidden_dim}H, "
3703
+ f"heads={current_arch.num_heads}/{current_arch.num_kv_heads}")
3704
+ logger.warning(f" Starting fresh (architecture was upgraded)")
3705
+ # Delete incompatible checkpoint
3706
+ try:
3707
+ path.unlink()
3708
+ logger.info(f"Deleted incompatible checkpoint: {path}")
3709
+ except Exception:
3710
+ pass
3711
+ return False
3712
+ else:
3713
+ logger.warning("Legacy checkpoint without architecture info - starting fresh")
3714
+ # Delete legacy checkpoint
3715
+ try:
3716
+ path.unlink()
3717
+ logger.info(f"Deleted legacy checkpoint: {path}")
3718
+ except Exception:
3719
+ pass
3720
+ return False
3721
+
3722
+ # Check layer assignment compatibility
3723
+ saved_layers = set(checkpoint.get("layer_ids", []))
3724
+ current_layers = set(self.my_layer_ids)
3725
+
3726
+ if saved_layers != current_layers:
3727
+ # Layers changed - try to load what we can
3728
+ common_layers = saved_layers.intersection(current_layers)
3729
+ if common_layers:
3730
+ logger.warning(f"Layer assignment changed: saved={len(saved_layers)}, current={len(current_layers)}, common={len(common_layers)}")
3731
+ logger.info(f"Will load {len(common_layers)} common layers from checkpoint")
3732
+ else:
3733
+ logger.warning(f"No common layers between checkpoint and current assignment, starting fresh")
3734
+ return False
3735
+
3736
+ # Load layer weights
3737
+ for layer_id, state_dict in checkpoint.get("layers", {}).items():
3738
+ layer_id = int(layer_id)
3739
+ if layer_id in self.model.my_layers:
3740
+ self.model.my_layers[layer_id].load_state_dict(state_dict)
3741
+
3742
+ # Load embedding if present (handle vocab size changes gracefully)
3743
+ if self.model.embedding and "embedding" in checkpoint:
3744
+ self._load_embedding_with_vocab_expansion(
3745
+ self.model.embedding,
3746
+ checkpoint["embedding"],
3747
+ "embedding"
3748
+ )
3749
+
3750
+ # Load LM head if present (handle vocab size changes gracefully)
3751
+ if self.model.lm_head and "lm_head" in checkpoint:
3752
+ self._load_lm_head_with_vocab_expansion(
3753
+ self.model.lm_head,
3754
+ checkpoint["lm_head"],
3755
+ "lm_head"
3756
+ )
3757
+
3758
+ # Load final norm if present
3759
+ if self.model.final_norm and "final_norm" in checkpoint:
3760
+ self.model.final_norm.load_state_dict(checkpoint["final_norm"])
3761
+
3762
+ # Restore training state
3763
+ self.total_training_rounds = checkpoint.get("total_training_rounds", 0)
3764
+
3765
+ # Store optimizer state for later loading (after optimizer is created)
3766
+ if "optimizer" in checkpoint:
3767
+ self._pending_optimizer_state = checkpoint["optimizer"]
3768
+
3769
+ # Store DiLoCo state for later loading (after swarm is created)
3770
+ if "diloco" in checkpoint:
3771
+ self._pending_diloco_state = checkpoint["diloco"]
3772
+ logger.info("[NODE] DiLoCo state found in checkpoint, will restore after swarm init")
3773
+
3774
+ # Count how many layers were actually loaded
3775
+ loaded_layer_count = sum(1 for lid in checkpoint.get("layers", {}).keys() if int(lid) in self.model.my_layers)
3776
+ logger.info(f"Checkpoint loaded: {self.total_training_rounds} training rounds, "
3777
+ f"{loaded_layer_count}/{len(current_layers)} layers from {path}")
3778
+ return True
3779
+
3780
+ except Exception as e:
3781
+ logger.warning(f"Failed to load checkpoint: {e}, starting fresh")
3782
+ return False
3783
+
3784
+ def _restore_pending_state(self):
3785
+ """
3786
+ Restore optimizer and DiLoCo state after they are initialized.
3787
+
3788
+ Called after swarm/optimizer are set up to restore checkpoint state.
3789
+ """
3790
+ # Restore optimizer state
3791
+ if hasattr(self, '_pending_optimizer_state') and self._pending_optimizer_state:
3792
+ if hasattr(self, 'optimizer') and self.optimizer:
3793
+ try:
3794
+ self.optimizer.load_state_dict(self._pending_optimizer_state)
3795
+ logger.info("[NODE] Restored optimizer state from checkpoint")
3796
+ except Exception as e:
3797
+ logger.warning(f"[NODE] Could not restore optimizer state: {e}")
3798
+ del self._pending_optimizer_state
3799
+
3800
+ # Restore DiLoCo state
3801
+ if hasattr(self, '_pending_diloco_state') and self._pending_diloco_state:
3802
+ if hasattr(self, 'swarm') and self.swarm:
3803
+ diloco = getattr(self.swarm, 'diloco_trainer', None)
3804
+ if diloco and hasattr(diloco, 'load_state_dict'):
3805
+ try:
3806
+ diloco.load_state_dict(self._pending_diloco_state)
3807
+ logger.info(f"[NODE] Restored DiLoCo state (inner_step={diloco.stats.inner_step_count})")
3808
+ except Exception as e:
3809
+ logger.warning(f"[NODE] Could not restore DiLoCo state: {e}")
3810
+ del self._pending_diloco_state
3811
+
3812
+ # Class-level save lock to prevent concurrent checkpoint saves
3813
+ _checkpoint_save_lock = threading.Lock()
3814
+ _checkpoint_save_in_progress = False
3815
+
3816
+ def _save_checkpoint(self, async_save: bool = True):
3817
+ """
3818
+ Smart checkpoint saving with STREAMING ASYNC for memory-constrained systems.
3819
+
3820
+ THREE MODES:
3821
+ 1. BULK ASYNC (>32GB free OR >2.5x checkpoint): Clone all, save in thread
3822
+ 2. STREAMING ASYNC (>500MB free): Clone one layer at a time, save incrementally
3823
+ 3. SYNC (<500MB free): Blocking save (last resort)
3824
+
3825
+ Streaming async blocks training only during the brief snapshot (~1-2s for 110 layers),
3826
+ then saves to disk in background (~10-60s) while training continues.
3827
+
3828
+ Thread-safe: concurrent saves are serialized via lock.
3829
+ """
3830
+ if not self.model:
3831
+ return
3832
+
3833
+ # Prevent concurrent saves (async save might still be in progress)
3834
+ if not DynamicNeuroNode._checkpoint_save_lock.acquire(blocking=False):
3835
+ logger.debug("[NODE] Checkpoint save skipped - another save in progress")
3836
+ return
3837
+
3838
+ # Use wallet_id for stable checkpoint path
3839
+ path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
3840
+ temp_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt.tmp"
3841
+
3842
+ # 1. Assess memory situation (respect configured --memory limit!)
3843
+ try:
3844
+ total_params = sum(p.numel() for p in self.model.parameters())
3845
+ checkpoint_size_mb = (total_params * 4) / (1024 * 1024)
3846
+
3847
+ # Get ACTUAL available memory (respecting configured limit)
3848
+ vm = psutil.virtual_memory()
3849
+ system_available_mb = vm.available / (1024 * 1024)
3850
+
3851
+ # Check current process memory usage
3852
+ process = psutil.Process()
3853
+ process_used_mb = process.memory_info().rss / (1024 * 1024)
3854
+
3855
+ # If user set --memory limit, respect it
3856
+ # Available = min(system_available, configured_limit - current_usage)
3857
+ configured_limit = getattr(self, 'available_memory_mb', None)
3858
+ if configured_limit:
3859
+ # How much headroom do we have within our configured limit?
3860
+ headroom_mb = max(0, configured_limit - process_used_mb)
3861
+ # Use the more conservative of system available or our headroom
3862
+ available_mb = min(system_available_mb, headroom_mb)
3863
+ logger.debug(f"[NODE] Memory check: process={process_used_mb:.0f}MB, "
3864
+ f"limit={configured_limit:.0f}MB, headroom={headroom_mb:.0f}MB, "
3865
+ f"system_free={system_available_mb:.0f}MB, using={available_mb:.0f}MB")
3866
+ else:
3867
+ available_mb = system_available_mb
3868
+
3869
+ # Determine save mode based on available headroom
3870
+ # Bulk async needs 2.5x checkpoint size to clone everything
3871
+ can_bulk_async = (available_mb > (checkpoint_size_mb * 2.5)) or (available_mb > 32000)
3872
+ # Streaming async just needs enough for 1 layer (~50-100MB typically)
3873
+ can_stream_async = available_mb > 500
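+ # Worked example (illustrative numbers): with checkpoint_size_mb = 2000 and
+ # available_mb = 3000, bulk async would need 2000 * 2.5 = 5000 MB and is skipped,
+ # but streaming async (> 500 MB free) is used; below 500 MB free the save falls
+ # back to the blocking sync path.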
3874
+
3875
+ except Exception as e:
3876
+ logger.warning(f"[NODE] Could not assess memory: {e}. Using streaming async.")
3877
+ can_bulk_async = False
3878
+ can_stream_async = True
3879
+ checkpoint_size_mb = 0
3880
+ available_mb = 0
3881
+
3882
+ try:
3883
+ # ============ ATOMIC SNAPSHOT PHASE ============
3884
+ # ALL state must be captured together to ensure consistency.
3885
+ # DiLoCo state + model weights must be from the same "moment in time".
3886
+
3887
+ # Helper to deep-clone DiLoCo state (ALL tensors to CPU for async safety)
3888
+ def _clone_diloco_state():
3889
+ if not hasattr(self, 'swarm') or not self.swarm:
3890
+ return None
3891
+ diloco = getattr(self.swarm, 'diloco_trainer', None)
3892
+ if not diloco or not hasattr(diloco, 'state_dict'):
3893
+ return None
3894
+ try:
3895
+ state = diloco.state_dict()
3896
+
3897
+ # Deep clone optimizer state (handles both PyTorch and custom formats)
3898
+ def _clone_optimizer_state(opt_state):
3899
+ if opt_state is None:
3900
+ return None
3901
+ cloned = {}
3902
+ for key, value in opt_state.items():
3903
+ if isinstance(value, torch.Tensor):
3904
+ # Direct tensor (e.g., in custom optimizers)
3905
+ cloned[key] = value.detach().clone().cpu()
3906
+ elif isinstance(value, dict):
3907
+ # Nested dict (e.g., 'state' or 'velocity' dicts)
3908
+ cloned[key] = {}
3909
+ for k, v in value.items():
3910
+ if isinstance(v, torch.Tensor):
3911
+ cloned[key][k] = v.detach().clone().cpu()
3912
+ elif isinstance(v, dict):
3913
+ # PyTorch optimizer 'state' has param_idx -> {key: tensor}
3914
+ cloned[key][k] = {}
3915
+ for kk, vv in v.items():
3916
+ if isinstance(vv, torch.Tensor):
3917
+ cloned[key][k][kk] = vv.detach().clone().cpu()
3918
+ else:
3919
+ cloned[key][k][kk] = vv
3920
+ else:
3921
+ cloned[key][k] = v
3922
+ elif isinstance(value, list):
3923
+ # List (e.g., param_groups) - shallow copy is fine
3924
+ cloned[key] = list(value)
3925
+ else:
3926
+ # Scalar values (lr, momentum, etc.)
3927
+ cloned[key] = value
3928
+ return cloned
3929
+
3930
+ # Deep clone all tensors to CPU for async safety
3931
+ cloned = {
3932
+ 'config': dict(state.get('config', {})),
3933
+ 'inner_optimizer': _clone_optimizer_state(state.get('inner_optimizer')),
3934
+ 'outer_optimizer': _clone_optimizer_state(state.get('outer_optimizer')),
3935
+ 'initial_weights': {
3936
+ k: v.detach().clone().cpu()
3937
+ for k, v in state.get('initial_weights', {}).items()
3938
+ },
3939
+ 'stats': dict(state.get('stats', {})),
3940
+ 'phase': state.get('phase', 'idle'),
3941
+ }
3942
+ return cloned
3943
+ except Exception as e:
3944
+ logger.warning(f"[NODE] Could not snapshot DiLoCo state: {e}")
3945
+ return None
3946
+
3947
+ # ============ MODE 1: BULK ASYNC (plenty of memory) ============
3948
+ if async_save and can_bulk_async:
3949
+ logger.debug(f"[NODE] Checkpoint: BULK ASYNC (Free: {available_mb:.0f}MB)")
3950
+
3951
+ # Capture everything atomically
3952
+ checkpoint = {
3953
+ "node_id": self.node_id,
3954
+ "layer_ids": list(self.my_layer_ids),
3955
+ "architecture": self.model.architecture.to_dict(),
3956
+ "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
3957
+ "has_embedding": self.model.has_embedding,
3958
+ "has_lm_head": self.model.has_lm_head,
3959
+ "total_training_rounds": self.total_training_rounds,
3960
+ "current_loss": self.current_loss,
3961
+ "timestamp": time.time(),
3962
+ "layers": {
3963
+ layer_id: {k: v.clone().cpu() for k, v in layer.state_dict().items()}
3964
+ for layer_id, layer in self.model.my_layers.items()
3965
+ },
3966
+ }
3967
+ if self.model.embedding:
3968
+ checkpoint["embedding"] = {k: v.clone().cpu() for k, v in self.model.embedding.state_dict().items()}
3969
+ if self.model.lm_head:
3970
+ checkpoint["lm_head"] = {k: v.clone().cpu() for k, v in self.model.lm_head.state_dict().items()}
3971
+ if self.model.final_norm:
3972
+ checkpoint["final_norm"] = {k: v.clone().cpu() for k, v in self.model.final_norm.state_dict().items()}
3973
+
3974
+ # DiLoCo state captured AFTER model weights (both in same atomic snapshot)
3975
+ diloco_state = _clone_diloco_state()
3976
+ if diloco_state:
3977
+ checkpoint["diloco"] = diloco_state
3978
+
3979
+ def _do_bulk_save():
3980
+ try:
3981
+ torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
3982
+ import shutil
3983
+ shutil.move(str(temp_path), str(path))
3984
+ logger.info(f"[NODE] Checkpoint saved ({len(self.my_layer_ids)} layers)")
3985
+ except Exception as e:
3986
+ logger.error(f"[NODE] Checkpoint save failed: {e}")
3987
+ if temp_path.exists(): temp_path.unlink()
3988
+ finally:
3989
+ DynamicNeuroNode._checkpoint_save_lock.release()
3990
+
3991
+ # Use daemon=False so checkpoint completes even during shutdown
3992
+ threading.Thread(target=_do_bulk_save, daemon=False).start()
3993
+ return # Lock will be released by background thread
3994
+
3995
+ # ============ MODE 2: STREAMING ASYNC (memory-efficient) ============
3996
+ if async_save and can_stream_async:
3997
+ logger.info(f"[NODE] Checkpoint: STREAMING ASYNC (Free: {available_mb:.0f}MB, cloning {len(self.model.my_layers)} layers)")
3998
+
3999
+ # SNAPSHOT PHASE: Clone one layer at a time into a list
4000
+ # This brief pause (~1-2s) ensures consistency without needing full clone memory
4001
+ snapshot_start = time.time()
4002
+
4003
+ # Capture metadata first (lightweight)
4004
+ checkpoint_meta = {
4005
+ "node_id": self.node_id,
4006
+ "layer_ids": list(self.my_layer_ids),
4007
+ "architecture": self.model.architecture.to_dict(),
4008
+ "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
4009
+ "has_embedding": self.model.has_embedding,
4010
+ "has_lm_head": self.model.has_lm_head,
4011
+ "total_training_rounds": self.total_training_rounds,
4012
+ "current_loss": self.current_loss,
4013
+ "timestamp": time.time(),
4014
+ }
4015
+
4016
+ # Clone layers one at a time (memory efficient)
4017
+ layer_snapshots = []
4018
+ for layer_id, layer in self.model.my_layers.items():
4019
+ layer_state = {k: v.detach().clone().cpu() for k, v in layer.state_dict().items()}
4020
+ layer_snapshots.append((layer_id, layer_state))
4021
+
4022
+ # Clone special modules
4023
+ embedding_snapshot = None
4024
+ lm_head_snapshot = None
4025
+ final_norm_snapshot = None
4026
+
4027
+ if self.model.embedding:
4028
+ embedding_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.embedding.state_dict().items()}
4029
+ if self.model.lm_head:
4030
+ lm_head_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.lm_head.state_dict().items()}
4031
+ if self.model.final_norm:
4032
+ final_norm_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.final_norm.state_dict().items()}
4033
+
4034
+ # DiLoCo state - captured in SAME snapshot window as model weights
4035
+ diloco_snapshot = _clone_diloco_state()
4036
+
4037
+ snapshot_time = time.time() - snapshot_start
4038
+ logger.debug(f"[NODE] Snapshot complete in {snapshot_time:.1f}s, starting async save")
4039
+
4040
+ # ASYNC SAVE PHASE: Write to disk in background thread
4041
+ # All data is now cloned and owned by this closure - safe for async
4042
+ def _do_streaming_save():
4043
+ try:
4044
+ checkpoint = dict(checkpoint_meta)
4045
+ checkpoint["layers"] = {lid: lstate for lid, lstate in layer_snapshots}
4046
+ if embedding_snapshot:
4047
+ checkpoint["embedding"] = embedding_snapshot
4048
+ if lm_head_snapshot:
4049
+ checkpoint["lm_head"] = lm_head_snapshot
4050
+ if final_norm_snapshot:
4051
+ checkpoint["final_norm"] = final_norm_snapshot
4052
+ if diloco_snapshot:
4053
+ checkpoint["diloco"] = diloco_snapshot
4054
+
4055
+ torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
4056
+ import shutil
4057
+ shutil.move(str(temp_path), str(path))
4058
+ logger.info(f"[NODE] Checkpoint saved ({len(layer_snapshots)} layers)")
4059
+ except Exception as e:
4060
+ logger.error(f"[NODE] Checkpoint save failed: {e}")
4061
+ if temp_path.exists(): temp_path.unlink()
4062
+ finally:
4063
+ DynamicNeuroNode._checkpoint_save_lock.release()
4064
+
4065
+ # Use daemon=False so checkpoint completes even during shutdown
4066
+ threading.Thread(target=_do_streaming_save, daemon=False).start()
4067
+ return # Lock will be released by background thread
4068
+
4069
+ # ============ MODE 3: SYNC (last resort, very low memory) ============
4070
+ logger.warning(f"[NODE] Checkpoint: SYNC mode (Free: {available_mb:.0f}MB < 500MB minimum)")
4071
+
4072
+ checkpoint = {
4073
+ "node_id": self.node_id,
4074
+ "layer_ids": list(self.my_layer_ids),
4075
+ "architecture": self.model.architecture.to_dict(),
4076
+ "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
4077
+ "has_embedding": self.model.has_embedding,
4078
+ "has_lm_head": self.model.has_lm_head,
4079
+ "total_training_rounds": self.total_training_rounds,
4080
+ "current_loss": self.current_loss,
4081
+ "timestamp": time.time(),
4082
+ "layers": {
4083
+ layer_id: layer.state_dict()
4084
+ for layer_id, layer in self.model.my_layers.items()
4085
+ },
4086
+ }
4087
+ if self.model.embedding:
4088
+ checkpoint["embedding"] = self.model.embedding.state_dict()
4089
+ if self.model.lm_head:
4090
+ checkpoint["lm_head"] = self.model.lm_head.state_dict()
4091
+ if self.model.final_norm:
4092
+ checkpoint["final_norm"] = self.model.final_norm.state_dict()
4093
+
4094
+ # DiLoCo state (no need to clone for sync - we block anyway)
4095
+ diloco_state = _clone_diloco_state()
4096
+ if diloco_state:
4097
+ checkpoint["diloco"] = diloco_state
4098
+
4099
+ torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
4100
+ import shutil
4101
+ shutil.move(str(temp_path), str(path))
4102
+ logger.info(f"[NODE] Checkpoint saved ({len(self.my_layer_ids)} layers)")
4103
+
4104
+ # Sync mode completed successfully, release lock
4105
+ DynamicNeuroNode._checkpoint_save_lock.release()
4106
+
4107
+ except Exception as e:
4108
+ logger.error(f"[NODE] Checkpoint preparation failed: {type(e).__name__}: {e}")
4109
+ try:
4110
+ if temp_path.exists(): temp_path.unlink()
4111
+ except Exception:
4112
+ pass
4113
+ # Release lock on exception
4114
+ DynamicNeuroNode._checkpoint_save_lock.release()
4115
+
4116
+
4117
+ def create_dynamic_node(
4118
+ node_token: str,
4119
+ port: int = 8000,
4120
+ tracker_url: str = "https://neuroshard.com/api/tracker",
4121
+ available_memory_mb: Optional[float] = None,
4122
+ enable_training: bool = True,
4123
+ max_storage_mb: float = 100.0,
4124
+ max_cpu_threads: Optional[int] = None,
4125
+ device: str = "auto",
4126
+ p2p_manager: Optional[Any] = None, # NEW: Pass P2P for DHT discovery during layer assignment
4127
+ ) -> DynamicNeuroNode:
4128
+ """
4129
+ Create and start a dynamic node.
4130
+
4131
+ MULTI-NODE SUPPORT:
4132
+ If the same token is used on multiple machines or ports, each gets a unique
4133
+ node_id (based on machine + port) while sharing the same wallet_id (based on token).
4134
+
4135
+ This means:
4136
+ - Each physical node has a unique network identity
4137
+ - Earnings accumulate to the same NEURO wallet
4138
+ - No conflicts in DHT/layer assignments
4139
+
4140
+ FULLY DECENTRALIZED:
4141
+ If p2p_manager is provided, DHT is used for network discovery during layer
4142
+ assignment. No tracker fallbacks - pure P2P!
4143
+ """
4144
+ from neuroshard.utils.hardware import get_instance_id
4145
+
4146
+ # Generate instance-specific node_id
4147
+ instance_id = get_instance_id(port)
4148
+
4149
+ # Combine token with instance for unique network identity
4150
+ # wallet_id (from token alone) is used for NEURO earnings
4151
+ # node_id (from token + instance) is used for network identity
4152
+ combined = f"{node_token}:{instance_id}"
4153
+ node_id = str(int(hashlib.sha256(combined.encode()).hexdigest(), 16))
4154
+
4155
+ # Log multi-node info
4156
+ wallet_id = hashlib.sha256(node_token.encode()).hexdigest()[:16]
4157
+ logger.info(f"Instance ID: {instance_id} (machine+port)")
4158
+ logger.info(f"Wallet ID: {wallet_id}... (for NEURO earnings)")
4159
+ logger.info(f"Node ID: {node_id[:16]}... (unique network identity)")
4160
+
4161
+ node = DynamicNeuroNode(
4162
+ node_id=node_id,
4163
+ port=port,
4164
+ tracker_url=tracker_url,
4165
+ node_token=node_token,
4166
+ available_memory_mb=available_memory_mb,
4167
+ enable_training=enable_training,
4168
+ max_storage_mb=max_storage_mb,
4169
+ max_cpu_threads=max_cpu_threads,
4170
+ device=device,
4171
+ )
4172
+
4173
+ # Store instance info for debugging
4174
+ node.instance_id = instance_id
4175
+ node.wallet_id = wallet_id
4176
+
4177
+ # CRITICAL: Connect P2P BEFORE start() so DHT is available for layer discovery!
4178
+ # This enables fully decentralized network discovery without tracker fallbacks.
4179
+ if p2p_manager:
4180
+ node.p2p_manager = p2p_manager
4181
+ logger.info("P2P connected BEFORE start - DHT available for network discovery")
4182
+
4183
+ node.start()
4184
+
4185
+ return node
4186
+
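+ # Usage sketch (illustrative; the token, port, and memory budget below are
+ # placeholder values, not defaults defined by this function):
+ #
+ #     node = create_dynamic_node(
+ #         node_token="my-node-token",
+ #         port=8000,
+ #         available_memory_mb=8192.0,
+ #         enable_training=True,
+ #     )
+ #     print(node.get_stats()["my_layers"])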