nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
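
The largest addition is neuroshard/core/model/dynamic.py (+4186 lines), whose diff follows. For orientation, here is a minimal, hypothetical sketch of how the DynamicLayerPool API it introduces might be exercised once the wheel is installed; the solo-mode setup (no DHT) and all values are illustrative assumptions, not part of the package:

# Hypothetical usage sketch of the DynamicLayerPool API from
# neuroshard/core/model/dynamic.py (see the diff below); not part of the package.
from neuroshard.core.model.dynamic import DynamicLayerPool

pool = DynamicLayerPool(dht_protocol=None)  # solo mode: no DHT discovery

# A single node with ~8 GB available registers and is assigned layers
# based on its capacity (illustrative identifiers and addresses).
node_id = "a" * 64
layer_ids = pool.register_node(
    node_id=node_id,
    node_url="http://127.0.0.1:8000",
    grpc_addr="127.0.0.1:50051",
    available_memory_mb=8192.0,
    staked_amount=0.0,
)
print("assigned layers:", layer_ids)

# Keep the assignments alive, then inspect capacity and the pipeline route.
pool.heartbeat(node_id, layer_ids)
print("network capacity:", pool.get_network_capacity())
print("pipeline route:", pool.get_pipeline_route())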
|
@@ -0,0 +1,4186 @@
|
|
|
1
|
+
"""
|
|
2
|
+
Dynamic Model Architecture - True Decentralization
|
|
3
|
+
|
|
4
|
+
This module implements a model that grows and shrinks with the network:
|
|
5
|
+
- NO fixed phases or model sizes
|
|
6
|
+
- Model size = what the network can collectively hold
|
|
7
|
+
- Each node contributes based on its available memory
|
|
8
|
+
- More memory = more layers = more NEURO rewards
|
|
9
|
+
|
|
10
|
+
Key Concepts:
|
|
11
|
+
1. LAYER POOL: The network maintains a pool of layers
|
|
12
|
+
2. DYNAMIC ASSIGNMENT: Nodes claim layers based on their capacity
|
|
13
|
+
3. ORGANIC GROWTH: As more nodes join, model can have more layers
|
|
14
|
+
4. GRACEFUL DEGRADATION: If nodes leave, layers are redistributed
|
|
15
|
+
|
|
16
|
+
Example:
|
|
17
|
+
Day 1: 10 nodes with 4GB each = 40GB total = ~10B params possible
|
|
18
|
+
Day 30: 100 nodes with avg 8GB = 800GB total = ~200B params possible
|
|
19
|
+
|
|
20
|
+
The model AUTOMATICALLY grows as capacity grows.
|
|
21
|
+
No voting, no phases, no central coordination.
|
|
22
|
+
"""
|
|
23
|
+
|
|
24
|
+
import torch
|
|
25
|
+
import torch.nn as nn
|
|
26
|
+
import threading
|
|
27
|
+
import time
|
|
28
|
+
import logging
|
|
29
|
+
import hashlib
|
|
30
|
+
import math
|
|
31
|
+
import psutil # For adaptive memory management
|
|
32
|
+
from typing import Optional, Dict, List, Tuple, Any, Set
|
|
33
|
+
from dataclasses import dataclass, field
|
|
34
|
+
from collections import defaultdict
|
|
35
|
+
from urllib.parse import urlparse
|
|
36
|
+
|
|
37
|
+
logger = logging.getLogger(__name__)
|
|
38
|
+
|
|
39
|
+
|
|
40
|
+
# Dynamic architecture - NO MORE FIXED DIMENSIONS!
|
|
41
|
+
# Architecture is now calculated based on network capacity
|
|
42
|
+
|
|
43
|
+
# Dynamic vocabulary - starts at 32K, expands WITHOUT LIMIT as network grows
|
|
44
|
+
# The embedding/lm_head grow in chunks when tokenizer vocabulary exceeds current capacity
|
|
45
|
+
INITIAL_VOCAB_SIZE = 32000 # Starting size (efficient for small networks)
|
|
46
|
+
VOCAB_GROWTH_CHUNK = 32000 # Expand by 32K at a time (efficient GPU memory alignment)
|
|
47
|
+
|
|
48
|
+
# NO HARD LIMIT - vocabulary grows with the network
|
|
49
|
+
# The only real constraints are:
|
|
50
|
+
# - Memory: ~4KB per token (at hidden_dim=1024) or ~16KB (at hidden_dim=4096)
|
|
51
|
+
# - Practical: Most use cases covered under 1M tokens
|
|
52
|
+
# For reference: GPT-4 ~100K, Claude ~100K, Gemini ~256K
|
|
53
|
+
# NeuroShard can grow FAR beyond these as a truly decentralized, ever-growing LLM
|
|
54
|
+
MAX_VOCAB_SIZE = None # None = unlimited (constrained only by available memory)
|
|
55
|
+
|
|
56
|
+
# Import the new architecture scaler
|
|
57
|
+
from neuroshard.core.model.scaler import (
|
|
58
|
+
ModelArchitecture,
|
|
59
|
+
calculate_optimal_architecture,
|
|
60
|
+
should_upgrade_architecture,
|
|
61
|
+
estimate_memory_per_layer,
|
|
62
|
+
calculate_layer_assignment,
|
|
63
|
+
)
|
|
64
|
+
|
|
65
|
+
|
|
66
|
+
@dataclass
|
|
67
|
+
class LayerAssignment:
|
|
68
|
+
"""Assignment of a layer to a node."""
|
|
69
|
+
layer_id: int
|
|
70
|
+
node_id: str
|
|
71
|
+
node_url: str
|
|
72
|
+
grpc_addr: str
|
|
73
|
+
assigned_at: float = field(default_factory=time.time)
|
|
74
|
+
last_heartbeat: float = field(default_factory=time.time)
|
|
75
|
+
version: int = 0 # Training version
|
|
76
|
+
|
|
77
|
+
|
|
78
|
+
@dataclass
|
|
79
|
+
class NetworkCapacity:
|
|
80
|
+
"""Current network capacity."""
|
|
81
|
+
total_nodes: int
|
|
82
|
+
total_memory_mb: float
|
|
83
|
+
max_layers: int # How many layers the network can support
|
|
84
|
+
assigned_layers: int # How many layers are currently assigned
|
|
85
|
+
layer_coverage: Dict[int, int] # layer_id -> replica count
|
|
86
|
+
|
|
87
|
+
|
|
88
|
+
class DynamicLayerPool:
|
|
89
|
+
"""
|
|
90
|
+
Manages the dynamic pool of model layers across the network.
|
|
91
|
+
|
|
92
|
+
This is the core of true decentralization:
|
|
93
|
+
- Layers are assigned based on node capacity
|
|
94
|
+
- Model grows BOTH deeper AND wider as network expands
|
|
95
|
+
- Architecture auto-optimizes based on scaling laws
|
|
96
|
+
- No fixed model size or phases
|
|
97
|
+
|
|
98
|
+
SCALABILITY CONSIDERATIONS:
|
|
99
|
+
==========================
|
|
100
|
+
Small network (1-10 nodes):
|
|
101
|
+
- Each node may hold ALL layers (solo training mode)
|
|
102
|
+
- No layer replication needed
|
|
103
|
+
- Fast startup, immediate training
|
|
104
|
+
|
|
105
|
+
Medium network (10-100 nodes):
|
|
106
|
+
- Layers distributed across multiple nodes
|
|
107
|
+
- MIN_REPLICAS ensures redundancy
|
|
108
|
+
- Pipeline inference works across nodes
|
|
109
|
+
|
|
110
|
+
Large network (100-1000+ nodes):
|
|
111
|
+
- Strong layer distribution
|
|
112
|
+
- MAX_LAYERS_PER_NODE caps per-node load
|
|
113
|
+
- Architecture can scale to 100B+ params
|
|
114
|
+
"""
|
|
115
|
+
|
|
116
|
+
# Minimum replicas per layer for redundancy
|
|
117
|
+
MIN_REPLICAS = 2
|
|
118
|
+
|
|
119
|
+
# Layer heartbeat timeout
|
|
120
|
+
HEARTBEAT_TIMEOUT = 120 # seconds
|
|
121
|
+
|
|
122
|
+
# Maximum layers any single node can hold (prevents memory issues in large networks)
|
|
123
|
+
# In small networks (< 100 nodes), this is effectively unlimited
|
|
124
|
+
# In large networks, it ensures load is distributed
|
|
125
|
+
MAX_LAYERS_PER_NODE = 200
|
|
126
|
+
|
|
127
|
+
# Architecture recalculation triggers
|
|
128
|
+
# NOTE: RECALC_INTERVAL_NODES is now DYNAMIC - see _get_recalc_interval()
|
|
129
|
+
MIN_UPGRADE_IMPROVEMENT = 1.3 # Only upgrade if 30%+ better
|
|
130
|
+
|
|
131
|
+
@staticmethod
|
|
132
|
+
def _get_recalc_interval(node_count: int) -> int:
|
|
133
|
+
"""
|
|
134
|
+
Get dynamic architecture recalculation interval based on network size.
|
|
135
|
+
|
|
136
|
+
At small networks, recalculate more often (every node matters).
|
|
137
|
+
At large networks, recalculate less often (stability).
|
|
138
|
+
|
|
139
|
+
Formula: min(max(5, node_count // 10), 100)
|
|
140
|
+
- 1-50 nodes: every 5 nodes
|
|
141
|
+
- 51-100 nodes: every 5-10 nodes
|
|
142
|
+
- 100-1000 nodes: every 10-100 nodes
|
|
143
|
+
- 1000+ nodes: every 100 nodes
|
|
144
|
+
"""
|
|
145
|
+
return min(max(5, node_count // 10), 100)
|
|
146
|
+
|
|
147
|
+
def __init__(self, dht_protocol=None):
|
|
148
|
+
self.dht = dht_protocol
|
|
149
|
+
|
|
150
|
+
# Layer assignments
|
|
151
|
+
self.layer_assignments: Dict[int, List[LayerAssignment]] = defaultdict(list)
|
|
152
|
+
|
|
153
|
+
# Node capacities
|
|
154
|
+
self.node_capacities: Dict[str, float] = {} # node_id -> available_mb
|
|
155
|
+
|
|
156
|
+
# DYNAMIC VOCAB: Track current vocabulary capacity for memory calculation
|
|
157
|
+
# This is updated when vocab expands and affects layer assignment
|
|
158
|
+
self.vocab_capacity: int = INITIAL_VOCAB_SIZE
|
|
159
|
+
|
|
160
|
+
# DYNAMIC ARCHITECTURE (auto-updates as network grows)
|
|
161
|
+
self.current_architecture: Optional[ModelArchitecture] = None
|
|
162
|
+
self.architecture_version: int = 0
|
|
163
|
+
self.last_node_count: int = 0
|
|
164
|
+
|
|
165
|
+
# Legacy fields (for compatibility)
|
|
166
|
+
self.current_num_layers = 0
|
|
167
|
+
self.embedding_holder: Optional[str] = None
|
|
168
|
+
self.lm_head_holder: Optional[str] = None
|
|
169
|
+
|
|
170
|
+
# Threading
|
|
171
|
+
self.lock = threading.Lock()
|
|
172
|
+
|
|
173
|
+
logger.info("DynamicLayerPool initialized with dynamic width + depth scaling")
|
|
174
|
+
|
|
175
|
+
def _auto_recalculate_architecture(self):
|
|
176
|
+
"""
|
|
177
|
+
AUTOMATED architecture optimization - no human intervention needed.
|
|
178
|
+
|
|
179
|
+
Calculates optimal architecture based on current network capacity
|
|
180
|
+
and triggers upgrade if improvement is significant.
|
|
181
|
+
"""
|
|
182
|
+
total_memory = sum(self.node_capacities.values())
|
|
183
|
+
optimal = calculate_optimal_architecture(total_memory)
|
|
184
|
+
|
|
185
|
+
if self.current_architecture is None:
|
|
186
|
+
# First initialization
|
|
187
|
+
self.current_architecture = optimal
|
|
188
|
+
self.current_num_layers = optimal.num_layers
|
|
189
|
+
self.architecture_version = 1
|
|
190
|
+
logger.info(f"🚀 Initial architecture: {optimal.num_layers}L × {optimal.hidden_dim}H "
|
|
191
|
+
f"({optimal.estimate_params()/1e6:.0f}M params)")
|
|
192
|
+
return
|
|
193
|
+
|
|
194
|
+
# Check if upgrade is worthwhile
|
|
195
|
+
should_upgrade, reason = should_upgrade_architecture(
|
|
196
|
+
self.current_architecture,
|
|
197
|
+
optimal,
|
|
198
|
+
self.MIN_UPGRADE_IMPROVEMENT
|
|
199
|
+
)
|
|
200
|
+
|
|
201
|
+
if should_upgrade:
|
|
202
|
+
logger.warning(f"🔄 ARCHITECTURE UPGRADE TRIGGERED!")
|
|
203
|
+
logger.warning(f" {reason}")
|
|
204
|
+
logger.warning(f" Old: {self.current_architecture.num_layers}L × {self.current_architecture.hidden_dim}H")
|
|
205
|
+
logger.warning(f" New: {optimal.num_layers}L × {optimal.hidden_dim}H")
|
|
206
|
+
logger.warning(f" Nodes will gradually migrate to new architecture on restart")
|
|
207
|
+
|
|
208
|
+
# Update architecture (new nodes will use new arch)
|
|
209
|
+
self.current_architecture = optimal
|
|
210
|
+
self.current_num_layers = optimal.num_layers
|
|
211
|
+
self.architecture_version += 1
|
|
212
|
+
|
|
213
|
+
# TODO: Trigger distillation-based migration for existing nodes
|
|
214
|
+
# For now, existing nodes keep their architecture until restart
|
|
215
|
+
else:
|
|
216
|
+
logger.debug(f"Architecture recalculation: no upgrade needed ({reason})")
|
|
217
|
+
|
|
218
|
+
def register_node(
|
|
219
|
+
self,
|
|
220
|
+
node_id: str,
|
|
221
|
+
node_url: str,
|
|
222
|
+
grpc_addr: str,
|
|
223
|
+
available_memory_mb: float,
|
|
224
|
+
staked_amount: float = 0.0
|
|
225
|
+
) -> List[int]:
|
|
226
|
+
"""
|
|
227
|
+
Register a node and assign layers based on its capacity AND stake.
|
|
228
|
+
|
|
229
|
+
AUTOMATIC ARCHITECTURE SCALING:
|
|
230
|
+
- Periodically recalculates optimal architecture as network grows
|
|
231
|
+
- Triggers upgrades when capacity increases significantly
|
|
232
|
+
- New nodes automatically use latest architecture
|
|
233
|
+
|
|
234
|
+
Validator role requires:
|
|
235
|
+
1. Sufficient memory (>2GB)
|
|
236
|
+
2. Minimum stake (100 NEURO) - prevents Sybil attacks
|
|
237
|
+
|
|
238
|
+
Returns list of layer IDs assigned to this node.
|
|
239
|
+
"""
|
|
240
|
+
# Import validator requirements from centralized economics (with dynamic scaling!)
|
|
241
|
+
from neuroshard.core.economics.constants import (
|
|
242
|
+
VALIDATOR_MIN_MEMORY_MB,
|
|
243
|
+
get_dynamic_validator_stake
|
|
244
|
+
)
|
|
245
|
+
|
|
246
|
+
with self.lock:
|
|
247
|
+
self.node_capacities[node_id] = available_memory_mb
|
|
248
|
+
|
|
249
|
+
# AUTO-TRIGGER: Recalculate architecture if network grew significantly
|
|
250
|
+
node_count = len(self.node_capacities)
|
|
251
|
+
recalc_interval = self._get_recalc_interval(node_count)
|
|
252
|
+
if (node_count - self.last_node_count) >= recalc_interval:
|
|
253
|
+
self._auto_recalculate_architecture()
|
|
254
|
+
self.last_node_count = node_count
|
|
255
|
+
|
|
256
|
+
# Ensure we have an architecture
|
|
257
|
+
if self.current_architecture is None:
|
|
258
|
+
self._auto_recalculate_architecture()
|
|
259
|
+
|
|
260
|
+
# Calculate how many layers this node can hold
|
|
261
|
+
# Uses current architecture's dimensions (dynamic!)
|
|
262
|
+
# DEVICE-AWARE safety factors
|
|
263
|
+
# With gradient checkpointing always enabled, CPU can use higher factor
|
|
264
|
+
device_type = getattr(self, '_device_hint', 'cpu')
|
|
265
|
+
if device_type == 'cuda':
|
|
266
|
+
safety_factor = 0.6 # GPU: efficient memory usage
|
|
267
|
+
elif device_type == 'mps':
|
|
268
|
+
safety_factor = 0.5 # Apple Silicon: moderate overhead
|
|
269
|
+
else:
|
|
270
|
+
safety_factor = 0.5 # CPU: increased from 0.3 (checkpointing reduces overhead)
|
|
271
|
+
|
|
272
|
+
# SMART DISTRIBUTED TRAINING DESIGN:
|
|
273
|
+
# Goal: Enable meaningful training for ANY network size
|
|
274
|
+
#
|
|
275
|
+
# KEY INSIGHT: When a "full node" exists (has embedding + LM head),
|
|
276
|
+
# new nodes should NOT become Drivers (would create broken overlap).
|
|
277
|
+
# Instead, they should become WORKER+VALIDATOR to enable proper pipeline.
|
|
278
|
+
#
|
|
279
|
+
# Network States:
|
|
280
|
+
# 1. Empty network → First node becomes DRIVER, grows to full node
|
|
281
|
+
# 2. One full node exists → New node becomes WORKER+VALIDATOR (pipeline!)
|
|
282
|
+
# 3. Multiple partial nodes → Fill based on MIN_REPLICAS
|
|
283
|
+
|
|
284
|
+
# FULLY DECENTRALIZED: Discover network state from DHT ONLY (no tracker fallback!)
|
|
285
|
+
# Each node has its own LOCAL layer_pool, so we need DHT for network-wide view.
|
|
286
|
+
# BUT: DHT may have STALE data from previous runs - don't blindly trust it!
|
|
287
|
+
dht_layers = set()
|
|
288
|
+
|
|
289
|
+
# DHT discovery (P2P must be connected BEFORE start() for this to work!)
|
|
290
|
+
if self.dht:
|
|
291
|
+
dht_layers = self._discover_network_layers_from_dht()
|
|
292
|
+
if dht_layers:
|
|
293
|
+
highest_layer = max(dht_layers)
|
|
294
|
+
# SANITY CHECK: Only expand if DHT layers are "reasonable"
|
|
295
|
+
max_reasonable = max(32, self.current_num_layers * 2)
|
|
296
|
+
if highest_layer >= self.current_num_layers and highest_layer < max_reasonable:
|
|
297
|
+
self.current_num_layers = highest_layer + 1
|
|
298
|
+
logger.info(f"DHT discovery: network has {self.current_num_layers} layers")
|
|
299
|
+
elif highest_layer >= max_reasonable:
|
|
300
|
+
logger.warning(f"DHT shows {highest_layer + 1} layers but seems stale")
|
|
301
|
+
logger.warning(f"Ignoring stale DHT data - will use checkpoint layer count")
|
|
302
|
+
else:
|
|
303
|
+
logger.info("DHT: No existing layers found - this may be first node or peers not yet discovered")
|
|
304
|
+
else:
|
|
305
|
+
logger.warning("DHT not available - layer assignment will be solo mode")
|
|
306
|
+
|
|
307
|
+
driver_count = len(self.layer_assignments.get(0, []))
|
|
308
|
+
validator_layer = max(0, self.current_num_layers - 1) if self.current_num_layers > 0 else 0
|
|
309
|
+
validator_count = len(self.layer_assignments.get(validator_layer, []))
|
|
310
|
+
|
|
311
|
+
# CRITICAL: Check if a FULL NODE exists (FULLY DECENTRALIZED - DHT only!)
|
|
312
|
+
# A full node has BOTH layer 0 (embedding) AND last layer (LM head)
|
|
313
|
+
has_full_node = False
|
|
314
|
+
full_node_id = None
|
|
315
|
+
|
|
316
|
+
# Check local assignments first
|
|
317
|
+
if self.current_num_layers > 0:
|
|
318
|
+
layer_0_holders = {a.node_id for a in self.layer_assignments.get(0, [])}
|
|
319
|
+
last_layer_holders = {a.node_id for a in self.layer_assignments.get(validator_layer, [])}
|
|
320
|
+
full_node_ids = layer_0_holders & last_layer_holders # Intersection
|
|
321
|
+
if full_node_ids:
|
|
322
|
+
has_full_node = True
|
|
323
|
+
full_node_id = next(iter(full_node_ids))
|
|
324
|
+
logger.info(f"Full node detected (local): {full_node_id[:8]}... (has layers 0-{validator_layer})")
|
|
325
|
+
|
|
326
|
+
# Also check DHT - if both layer 0 AND last layer exist, a full node likely exists
|
|
327
|
+
if not has_full_node and dht_layers:
|
|
328
|
+
if 0 in dht_layers and validator_layer in dht_layers:
|
|
329
|
+
has_full_node = True
|
|
330
|
+
full_node_id = "unknown_from_dht"
|
|
331
|
+
logger.info(f"Full node detected (DHT): network has layers 0-{validator_layer}")
|
|
332
|
+
|
|
333
|
+
# ROLE ASSIGNMENT PRIORITY (redesigned for proper distributed training):
|
|
334
|
+
#
|
|
335
|
+
# If NO full node exists:
|
|
336
|
+
# 1. First node → DRIVER (will grow to full node)
|
|
337
|
+
# 2. Additional nodes → fill based on MIN_REPLICAS
|
|
338
|
+
#
|
|
339
|
+
# If a FULL NODE exists:
|
|
340
|
+
# - DON'T create another Driver (would overlap and break training!)
|
|
341
|
+
# - New nodes become WORKER+VALIDATOR for proper pipeline
|
|
342
|
+
# - This creates: FullNode[0-N] → NewNode[N+1 to Last] pipeline
|
|
343
|
+
|
|
344
|
+
if has_full_node:
|
|
345
|
+
# A full node exists - new nodes should NOT overlap with it
|
|
346
|
+
# Become WORKER+VALIDATOR to enable pipeline training!
|
|
347
|
+
if node_id == full_node_id:
|
|
348
|
+
# This IS the full node re-registering
|
|
349
|
+
role_hint = "DRIVER" # Keep as driver (already full)
|
|
350
|
+
needs_embedding = True
|
|
351
|
+
logger.info(f"This is the full node - keeping as DRIVER")
|
|
352
|
+
else:
|
|
353
|
+
# New node joining a network with a full node
|
|
354
|
+
# Become WORKER+VALIDATOR: get last layer + work backwards
|
|
355
|
+
# This enables pipeline: FullNode[embedding→layers] → Us[layers→LM_head]
|
|
356
|
+
role_hint = "WORKER+VALIDATOR"
|
|
357
|
+
needs_embedding = False # Don't need embedding, saves memory!
|
|
358
|
+
logger.info(f"Full node exists - becoming WORKER+VALIDATOR for pipeline training")
|
|
359
|
+
else:
|
|
360
|
+
# No full node - use standard role assignment
|
|
361
|
+
needs_more_drivers = driver_count < self.MIN_REPLICAS
|
|
362
|
+
needs_more_validators = validator_count < self.MIN_REPLICAS
|
|
363
|
+
|
|
364
|
+
if needs_more_drivers:
|
|
365
|
+
# Network needs more Drivers for MIN_REPLICAS
|
|
366
|
+
role_hint = "DRIVER"
|
|
367
|
+
needs_embedding = True
|
|
368
|
+
logger.info(f"Network needs more Drivers ({driver_count}/{self.MIN_REPLICAS})")
|
|
369
|
+
elif needs_more_validators and self.current_num_layers > 1:
|
|
370
|
+
# Network needs more Validators for MIN_REPLICAS
|
|
371
|
+
role_hint = "WORKER+VALIDATOR"
|
|
372
|
+
needs_embedding = False
|
|
373
|
+
logger.info(f"Network needs more Validators ({validator_count}/{self.MIN_REPLICAS})")
|
|
374
|
+
else:
|
|
375
|
+
# Network has enough of both → become Worker for layer redundancy
|
|
376
|
+
role_hint = "WORKER"
|
|
377
|
+
needs_embedding = False
|
|
378
|
+
logger.info(f"Network has enough Drivers ({driver_count}) and Validators ({validator_count})")
|
|
379
|
+
|
|
380
|
+
max_layers_for_node = calculate_layer_assignment(
|
|
381
|
+
available_memory_mb,
|
|
382
|
+
self.current_architecture,
|
|
383
|
+
safety_factor=safety_factor,
|
|
384
|
+
vocab_capacity=self.vocab_capacity,
|
|
385
|
+
training_mode=True,
|
|
386
|
+
needs_embedding=needs_embedding
|
|
387
|
+
)
|
|
388
|
+
|
|
389
|
+
logger.info(f"Layer calculation ({role_hint}): {available_memory_mb:.0f}MB × {safety_factor} safety = {max_layers_for_node} layers "
|
|
390
|
+
f"(needs_embedding={needs_embedding}, drivers={driver_count}, validators={validator_count})")
|
|
391
|
+
|
|
392
|
+
# SCALABILITY: Apply MAX_LAYERS_PER_NODE cap in large networks
|
|
393
|
+
# This prevents single nodes from hogging all layers and ensures
|
|
394
|
+
# load distribution as the network grows
|
|
395
|
+
node_count = len(self.node_capacities)
|
|
396
|
+
if node_count > 100:
|
|
397
|
+
# In large networks, cap layers per node
|
|
398
|
+
max_layers_for_node = min(max_layers_for_node, self.MAX_LAYERS_PER_NODE)
|
|
399
|
+
logger.debug(f"Large network ({node_count} nodes): capped to {max_layers_for_node} layers")
|
|
400
|
+
|
|
401
|
+
if max_layers_for_node < 1:
|
|
402
|
+
logger.warning(f"Node {node_id[:8]}... has insufficient memory for even 1 layer")
|
|
403
|
+
return []
|
|
404
|
+
|
|
405
|
+
# Find layers that need more replicas
|
|
406
|
+
assigned_layers = []
|
|
407
|
+
|
|
408
|
+
# SCALABILITY STRATEGY:
|
|
409
|
+
# High-capacity nodes (>8GB) are prioritized for Layer 0 (Driver) and Last Layer (Validator)
|
|
410
|
+
# This creates parallel pipelines ("Training Gangs")
|
|
411
|
+
is_high_capacity = available_memory_mb > 8000
|
|
412
|
+
is_medium_capacity = available_memory_mb > VALIDATOR_MIN_MEMORY_MB
|
|
413
|
+
|
|
414
|
+
# Count current validators for DYNAMIC stake requirement
|
|
415
|
+
num_drivers = len(self.layer_assignments[0])
|
|
416
|
+
num_validators = len(self.layer_assignments[max(0, self.current_num_layers - 1)])
|
|
417
|
+
|
|
418
|
+
# VALIDATOR ELIGIBILITY: Dynamic stake based on network size!
|
|
419
|
+
# - Few validators (1-10): 100 NEURO (accessible for bootstrap)
|
|
420
|
+
# - Many validators (1000+): 2500 NEURO (security at scale)
|
|
421
|
+
required_validator_stake = get_dynamic_validator_stake(num_validators)
|
|
422
|
+
has_validator_stake = staked_amount >= required_validator_stake
|
|
423
|
+
|
|
424
|
+
# DISTRIBUTED TRAINING ASSIGNMENT:
|
|
425
|
+
# Based on role_hint from capacity calculation, assign appropriate layers
|
|
426
|
+
|
|
427
|
+
last_layer = max(0, self.current_num_layers - 1) if self.current_num_layers > 0 else 0
|
|
428
|
+
|
|
429
|
+
# Determine role based on network needs and node capacity
|
|
430
|
+
if role_hint == "DRIVER":
|
|
431
|
+
# Become Driver: get Layer 0 + as many layers as we can
|
|
432
|
+
should_be_driver = True
|
|
433
|
+
should_be_validator = False # Driver doesn't need to also be Validator
|
|
434
|
+
logger.info(f"Node {node_id[:8]}... assigned as DRIVER (first node, will embed data)")
|
|
435
|
+
|
|
436
|
+
elif role_hint == "WORKER+VALIDATOR":
|
|
437
|
+
# Become Worker+Validator: skip Layer 0, get middle + last layer
|
|
438
|
+
# This is the KEY for distributed training - enables pipeline with loss computation
|
|
439
|
+
should_be_driver = False
|
|
440
|
+
should_be_validator = True # Need LM head to compute loss!
|
|
441
|
+
# For bootstrap, allow Validator without stake requirement
|
|
442
|
+
has_validator_stake = True # Bootstrap mode - stake checked later
|
|
443
|
+
logger.info(f"Node {node_id[:8]}... assigned as WORKER+VALIDATOR (will compute loss, enable distributed training)")
|
|
444
|
+
|
|
445
|
+
else: # "WORKER"
|
|
446
|
+
# Become Worker: middle layers only (redundancy)
|
|
447
|
+
should_be_driver = False
|
|
448
|
+
should_be_validator = False
|
|
449
|
+
logger.info(f"Node {node_id[:8]}... assigned as WORKER (middle layers, redundancy)")
|
|
450
|
+
|
|
451
|
+
# Assign Layer 0 (Driver) - if we should be Driver
|
|
452
|
+
if should_be_driver and len(assigned_layers) < max_layers_for_node:
|
|
453
|
+
if not any(a.node_id == node_id for a in self.layer_assignments[0]):
|
|
454
|
+
self._assign_layer(0, node_id, node_url, grpc_addr)
|
|
455
|
+
assigned_layers.append(0)
|
|
456
|
+
# Ensure current_num_layers accounts for layer 0
|
|
457
|
+
if self.current_num_layers == 0:
|
|
458
|
+
self.current_num_layers = 1
|
|
459
|
+
|
|
460
|
+
# 2. DISTRIBUTED LAYER ASSIGNMENT
|
|
461
|
+
# Based on role, fill appropriate layers for pipeline parallelism
|
|
462
|
+
is_driver = 0 in assigned_layers
|
|
463
|
+
last_layer = max(0, self.current_num_layers - 1)
|
|
464
|
+
|
|
465
|
+
if role_hint == "WORKER+VALIDATOR":
|
|
466
|
+
# CRITICAL: Assign LAST layer FIRST (for LM head) to ensure we can compute loss
|
|
467
|
+
if len(assigned_layers) < max_layers_for_node:
|
|
468
|
+
if last_layer not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
|
|
469
|
+
self._assign_layer(last_layer, node_id, node_url, grpc_addr)
|
|
470
|
+
assigned_layers.append(last_layer)
|
|
471
|
+
logger.info(f"Node {node_id[:8]}... assigned last layer {last_layer} (Validator role)")
|
|
472
|
+
|
|
473
|
+
# Then fill middle layers from the END (closer to Validator)
|
|
474
|
+
# This creates contiguous layer ranges: Driver has 0-N, Validator has M-31
|
|
475
|
+
for layer_id in range(last_layer - 1, 0, -1): # Reverse order, skip layer 0
|
|
476
|
+
if len(assigned_layers) >= max_layers_for_node:
|
|
477
|
+
break
|
|
478
|
+
if layer_id not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[layer_id]):
|
|
479
|
+
self._assign_layer(layer_id, node_id, node_url, grpc_addr)
|
|
480
|
+
assigned_layers.append(layer_id)
|
|
481
|
+
else:
|
|
482
|
+
# Driver or Worker: fill from layer 1 onwards
|
|
483
|
+
layers_to_check = range(1, self.current_num_layers)
|
|
484
|
+
|
|
485
|
+
for layer_id in layers_to_check:
|
|
486
|
+
if len(assigned_layers) >= max_layers_for_node:
|
|
487
|
+
break
|
|
488
|
+
|
|
489
|
+
current_replicas = len(self.layer_assignments[layer_id])
|
|
490
|
+
if current_replicas < self.MIN_REPLICAS:
|
|
491
|
+
if layer_id not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[layer_id]):
|
|
492
|
+
self._assign_layer(layer_id, node_id, node_url, grpc_addr)
|
|
493
|
+
assigned_layers.append(layer_id)
|
|
494
|
+
|
|
495
|
+
# 3. Assign Last Layer (Validator) if Driver with extra capacity
|
|
496
|
+
if should_be_validator and len(assigned_layers) < max_layers_for_node:
|
|
497
|
+
if last_layer not in assigned_layers and not any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
|
|
498
|
+
self._assign_layer(last_layer, node_id, node_url, grpc_addr)
|
|
499
|
+
assigned_layers.append(last_layer)
|
|
500
|
+
|
|
501
|
+
# 4. Calculate remaining capacity
|
|
502
|
+
remaining_capacity = max_layers_for_node - len(assigned_layers)
|
|
503
|
+
|
|
504
|
+
# 5. DHT layer discovery already done above (at role assignment)
|
|
505
|
+
# dht_layers variable is already populated
|
|
506
|
+
|
|
507
|
+
# 6. If we still have capacity, grow the model
|
|
508
|
+
if remaining_capacity > 0:
|
|
509
|
+
# CAP MODEL GROWTH for solo/early network as safety net
|
|
510
|
+
# Even with gradient checkpointing, too many layers cause OOM due to:
|
|
511
|
+
# - Optimizer states (2x model size for Adam)
|
|
512
|
+
# - Gradient accumulation during backward pass
|
|
513
|
+
# - PyTorch memory fragmentation
|
|
514
|
+
# 32 layers is safe for most devices (~200M params with 512 hidden)
|
|
515
|
+
MAX_SOLO_LAYERS = 32 # Conservative cap for solo node training
|
|
516
|
+
total_nodes = len(set(
|
|
517
|
+
a.node_id for assignments in self.layer_assignments.values()
|
|
518
|
+
for a in assignments
|
|
519
|
+
))
|
|
520
|
+
|
|
521
|
+
if total_nodes <= 2: # Solo or near-solo mode
|
|
522
|
+
max_growth = max(0, MAX_SOLO_LAYERS - len(assigned_layers))
|
|
523
|
+
if remaining_capacity > max_growth:
|
|
524
|
+
logger.warning(f"[SOLO MODE] Capping growth from {remaining_capacity} to {max_growth} layers "
|
|
525
|
+
f"(MAX_SOLO_LAYERS={MAX_SOLO_LAYERS} prevents OOM)")
|
|
526
|
+
remaining_capacity = max_growth
|
|
527
|
+
|
|
528
|
+
# Add new layers to grow the model
|
|
529
|
+
new_layers = self._grow_model(remaining_capacity, node_id, node_url, grpc_addr)
|
|
530
|
+
assigned_layers.extend(new_layers)
|
|
531
|
+
|
|
532
|
+
# Handle embedding and LM head tracking
|
|
533
|
+
# Any node with Layer 0 has embedding
|
|
534
|
+
if 0 in assigned_layers:
|
|
535
|
+
# Update tracking (just keeps one for reference, but multiple exist)
|
|
536
|
+
self.embedding_holder = node_id
|
|
537
|
+
logger.info(f"Node {node_id[:8]}... became a Driver (Layer 0)")
|
|
538
|
+
|
|
539
|
+
# Any node with highest assigned layer becomes Validator (LM head holder)
|
|
540
|
+
# CRITICAL FIX: Use max(assigned_layers), NOT current_num_layers - 1
|
|
541
|
+
# This handles the case where checkpoint has fewer layers than stale DHT data
|
|
542
|
+
if assigned_layers:
|
|
543
|
+
actual_last_layer = max(assigned_layers)
|
|
544
|
+
# Check if this node holds the highest layer in the network
|
|
545
|
+
# OR if there's no other holder for this layer yet
|
|
546
|
+
if not self.lm_head_holder or actual_last_layer >= (self.current_num_layers - 1):
|
|
547
|
+
self.lm_head_holder = node_id
|
|
548
|
+
self.current_num_layers = actual_last_layer + 1 # Update to match reality
|
|
549
|
+
logger.info(f"Node {node_id[:8]}... became a Validator (Layer {actual_last_layer})")
|
|
550
|
+
|
|
551
|
+
# EARLY NETWORK NOTICE: When there are <10 nodes, each must hold many/all layers
|
|
552
|
+
# This is TEMPORARY - as more nodes join, layers will be distributed
|
|
553
|
+
if len(assigned_layers) > 50:
|
|
554
|
+
logger.warning(f"Node {node_id[:8]}... holding {len(assigned_layers)} layers due to low network size")
|
|
555
|
+
logger.warning(f"This is temporary - as more nodes join, model will shard across network")
|
|
556
|
+
|
|
557
|
+
logger.info(f"Node {node_id[:8]}... registered: {len(assigned_layers)} layers assigned "
|
|
558
|
+
f"(capacity: {max_layers_for_node} layers, {available_memory_mb:.0f}MB)")
|
|
559
|
+
|
|
560
|
+
return assigned_layers
|
|
561
|
+
|
|
562
|
+
def _assign_layer(self, layer_id: int, node_id: str, node_url: str, grpc_addr: str):
|
|
563
|
+
"""Assign a layer to a node."""
|
|
564
|
+
assignment = LayerAssignment(
|
|
565
|
+
layer_id=layer_id,
|
|
566
|
+
node_id=node_id,
|
|
567
|
+
node_url=node_url,
|
|
568
|
+
grpc_addr=grpc_addr,
|
|
569
|
+
)
|
|
570
|
+
self.layer_assignments[layer_id].append(assignment)
|
|
571
|
+
|
|
572
|
+
# Announce to DHT
|
|
573
|
+
if self.dht:
|
|
574
|
+
try:
|
|
575
|
+
import json
|
|
576
|
+
key = f"layer_{layer_id}"
|
|
577
|
+
current = self.dht.lookup_value(key)
|
|
578
|
+
holders = json.loads(current) if current else []
|
|
579
|
+
if grpc_addr not in holders:
|
|
580
|
+
holders.append(grpc_addr)
|
|
581
|
+
self.dht.store(key, json.dumps(holders))
|
|
582
|
+
except Exception as e:
|
|
583
|
+
logger.debug(f"DHT announce failed: {e}")
|
|
584
|
+
|
|
585
|
+
def _discover_network_layers_from_dht(self) -> Set[int]:
|
|
586
|
+
"""
|
|
587
|
+
Query DHT to discover which layers exist in the network.
|
|
588
|
+
|
|
589
|
+
DECENTRALIZED COORDINATION:
|
|
590
|
+
- Each node announces "layer_X" to DHT when it holds layer X
|
|
591
|
+
- New nodes query DHT to see what layers already exist
|
|
592
|
+
- This prevents layer overlap without centralized coordination
|
|
593
|
+
"""
|
|
594
|
+
discovered_layers = set()
|
|
595
|
+
|
|
596
|
+
if not self.dht:
|
|
597
|
+
return discovered_layers
|
|
598
|
+
|
|
599
|
+
try:
|
|
600
|
+
# Query for layers 0-1000 (reasonable max)
|
|
601
|
+
# DHT lookup is fast - O(log N) hops
|
|
602
|
+
for layer_id in range(min(1000, self.current_num_layers + 100)):
|
|
603
|
+
key = f"layer_{layer_id}"
|
|
604
|
+
try:
|
|
605
|
+
value = self.dht.lookup_value(key)
|
|
606
|
+
if value:
|
|
607
|
+
discovered_layers.add(layer_id)
|
|
608
|
+
except Exception:
|
|
609
|
+
continue
|
|
610
|
+
|
|
611
|
+
if discovered_layers:
|
|
612
|
+
logger.info(f"DHT layer discovery: found {len(discovered_layers)} layers "
|
|
613
|
+
f"(range: {min(discovered_layers)}-{max(discovered_layers)})")
|
|
614
|
+
except Exception as e:
|
|
615
|
+
logger.debug(f"DHT layer discovery failed: {e}")
|
|
616
|
+
|
|
617
|
+
return discovered_layers
|
|
618
|
+
|
|
619
|
+
def _grow_model(
|
|
620
|
+
self,
|
|
621
|
+
num_new_layers: int,
|
|
622
|
+
node_id: str,
|
|
623
|
+
node_url: str,
|
|
624
|
+
grpc_addr: str
|
|
625
|
+
) -> List[int]:
|
|
626
|
+
"""
|
|
627
|
+
Grow the model by adding new layers.
|
|
628
|
+
|
|
629
|
+
This is how the model organically grows with the network!
|
|
630
|
+
"""
|
|
631
|
+
new_layers = []
|
|
632
|
+
|
|
633
|
+
for _ in range(num_new_layers):
|
|
634
|
+
new_layer_id = self.current_num_layers
|
|
635
|
+
self._assign_layer(new_layer_id, node_id, node_url, grpc_addr)
|
|
636
|
+
new_layers.append(new_layer_id)
|
|
637
|
+
self.current_num_layers += 1
|
|
638
|
+
|
|
639
|
+
if new_layers:
|
|
640
|
+
logger.info(f"Model grew: now {self.current_num_layers} layers "
|
|
641
|
+
f"(added layers {new_layers[0]}-{new_layers[-1]})")
|
|
642
|
+
|
|
643
|
+
return new_layers
|
|
644
|
+
|
|
645
|
+
def upgrade_to_validator(self, node_id: str, node_url: str, grpc_addr: str) -> bool:
|
|
646
|
+
"""
|
|
647
|
+
Upgrade a node to Validator role (assign LM head) when stake requirement is met.
|
|
648
|
+
|
|
649
|
+
This is called when a node stakes enough NEURO to become a Validator.
|
|
650
|
+
No restart required - the node's role is upgraded dynamically.
|
|
651
|
+
|
|
652
|
+
Returns True if upgrade was successful.
|
|
653
|
+
"""
|
|
654
|
+
from neuroshard.core.economics.constants import VALIDATOR_MIN_MEMORY_MB
|
|
655
|
+
|
|
656
|
+
with self.lock:
|
|
657
|
+
# Check if node has sufficient memory
|
|
658
|
+
memory = self.node_capacities.get(node_id, 0)
|
|
659
|
+
if memory < VALIDATOR_MIN_MEMORY_MB:
|
|
660
|
+
logger.warning(f"Node {node_id[:8]}... cannot be Validator: insufficient memory ({memory}MB)")
|
|
661
|
+
return False
|
|
662
|
+
|
|
663
|
+
# Check if already a validator
|
|
664
|
+
last_layer = max(0, self.current_num_layers - 1)
|
|
665
|
+
if any(a.node_id == node_id for a in self.layer_assignments[last_layer]):
|
|
666
|
+
logger.info(f"Node {node_id[:8]}... is already a Validator")
|
|
667
|
+
return True
|
|
668
|
+
|
|
669
|
+
# Assign the last layer (LM head)
|
|
670
|
+
self._assign_layer(last_layer, node_id, node_url, grpc_addr)
|
|
671
|
+
self.lm_head_holder = node_id
|
|
672
|
+
|
|
673
|
+
logger.info(f"Node {node_id[:8]}... upgraded to VALIDATOR (assigned layer {last_layer})")
|
|
674
|
+
return True
|
|
675
|
+
|
|
676
|
+
def demote_from_validator(self, node_id: str) -> bool:
|
|
677
|
+
"""
|
|
678
|
+
Demote a node from Validator role when stake drops below requirement.
|
|
679
|
+
|
|
680
|
+
This is called when:
|
|
681
|
+
1. A validator unstakes and drops below the required amount
|
|
682
|
+
2. The network grows and the required stake increases (tier change)
|
|
683
|
+
|
|
684
|
+
The node keeps its other layer assignments but loses the LM head.
|
|
685
|
+
|
|
686
|
+
Returns True if demotion was successful.
|
|
687
|
+
"""
|
|
688
|
+
with self.lock:
|
|
689
|
+
return self._demote_from_validator_unlocked(node_id)
|
|
690
|
+
|
|
691
|
+
def _demote_from_validator_unlocked(self, node_id: str) -> bool:
|
|
692
|
+
"""
|
|
693
|
+
Internal demotion logic (caller must hold self.lock).
|
|
694
|
+
|
|
695
|
+
Split out to avoid deadlock when called from validate_all_validators().
|
|
696
|
+
"""
|
|
697
|
+
last_layer = max(0, self.current_num_layers - 1)
|
|
698
|
+
|
|
699
|
+
# Check if node is currently a validator
|
|
700
|
+
current_assignments = self.layer_assignments.get(last_layer, [])
|
|
701
|
+
was_validator = any(a.node_id == node_id for a in current_assignments)
|
|
702
|
+
|
|
703
|
+
if not was_validator:
|
|
704
|
+
logger.debug(f"Node {node_id[:8]}... is not a validator, nothing to demote")
|
|
705
|
+
return False
|
|
706
|
+
|
|
707
|
+
# Remove from last layer assignments
|
|
708
|
+
self.layer_assignments[last_layer] = [
|
|
709
|
+
a for a in current_assignments if a.node_id != node_id
|
|
710
|
+
]
|
|
711
|
+
|
|
712
|
+
# Update lm_head_holder if this was the holder
|
|
713
|
+
if self.lm_head_holder == node_id:
|
|
714
|
+
# Find another validator if available
|
|
715
|
+
remaining = self.layer_assignments.get(last_layer, [])
|
|
716
|
+
if remaining:
|
|
717
|
+
self.lm_head_holder = remaining[0].node_id
|
|
718
|
+
else:
|
|
719
|
+
self.lm_head_holder = None
|
|
720
|
+
|
|
721
|
+
logger.warning(f"Node {node_id[:8]}... DEMOTED from Validator (insufficient stake)")
|
|
722
|
+
return True
|
|
723
|
+
|
|
724
|
+
def validate_all_validators(self, get_stake_fn) -> List[str]:
|
|
725
|
+
"""
|
|
726
|
+
Validate all current validators still meet stake requirements.
|
|
727
|
+
|
|
728
|
+
Called periodically or when stake tier changes to ensure all validators
|
|
729
|
+
have sufficient stake for the current network size.
|
|
730
|
+
|
|
731
|
+
IMPORTANT: Never demotes below MIN_VALIDATORS (2) to ensure the network
|
|
732
|
+
can always compute real loss. The stake requirement only applies to
|
|
733
|
+
validators beyond the minimum.
|
|
734
|
+
|
|
735
|
+
Args:
|
|
736
|
+
get_stake_fn: Function(node_id) -> float that returns current stake
|
|
737
|
+
|
|
738
|
+
Returns:
|
|
739
|
+
List of node_ids that were demoted
|
|
740
|
+
"""
|
|
741
|
+
from neuroshard.core.economics.constants import get_dynamic_validator_stake
|
|
742
|
+
|
|
743
|
+
MIN_VALIDATORS = 2 # Network needs at least 2 validators to function
|
|
744
|
+
|
|
745
|
+
demoted = []
|
|
746
|
+
|
|
747
|
+
with self.lock:
|
|
748
|
+
last_layer = max(0, self.current_num_layers - 1)
|
|
749
|
+
current_validators = list(self.layer_assignments.get(last_layer, []))
|
|
750
|
+
num_validators = len(current_validators)
|
|
751
|
+
|
|
752
|
+
# CRITICAL: Never demote below MIN_VALIDATORS
|
|
753
|
+
# Otherwise the network can't compute real cross-entropy loss!
|
|
754
|
+
if num_validators <= MIN_VALIDATORS:
|
|
755
|
+
logger.debug(f"Only {num_validators} validators - skipping stake check (minimum {MIN_VALIDATORS} required)")
|
|
756
|
+
return []
|
|
757
|
+
|
|
758
|
+
# Get current stake requirement
|
|
759
|
+
required_stake = get_dynamic_validator_stake(num_validators)
|
|
760
|
+
|
|
761
|
+
# Sort validators by stake (lowest first) to demote lowest-stake first
|
|
762
|
+
validators_with_stake = [
|
|
763
|
+
(assignment, get_stake_fn(assignment.node_id))
|
|
764
|
+
for assignment in current_validators
|
|
765
|
+
]
|
|
766
|
+
validators_with_stake.sort(key=lambda x: x[1]) # Lowest stake first
|
|
767
|
+
|
|
768
|
+
for assignment, node_stake in validators_with_stake:
|
|
769
|
+
# Check if we'd go below minimum
|
|
770
|
+
remaining_validators = num_validators - len(demoted)
|
|
771
|
+
if remaining_validators <= MIN_VALIDATORS:
|
|
772
|
+
logger.info(f"Stopping demotion: {remaining_validators} validators remain (minimum {MIN_VALIDATORS})")
|
|
773
|
+
break
|
|
774
|
+
|
|
775
|
+
if node_stake < required_stake:
|
|
776
|
+
logger.warning(
|
|
777
|
+
f"Validator {assignment.node_id[:8]}... has {node_stake:.0f} NEURO "
|
|
778
|
+
f"but {required_stake:.0f} required - DEMOTING"
|
|
779
|
+
)
|
|
780
|
+
# Use unlocked version since we already hold self.lock
|
|
781
|
+
if self._demote_from_validator_unlocked(assignment.node_id):
|
|
782
|
+
demoted.append(assignment.node_id)
|
|
783
|
+
|
|
784
|
+
return demoted
|
|
785
|
+
|
|
786
|
+
def unregister_node(self, node_id: str):
|
|
787
|
+
"""
|
|
788
|
+
Unregister a node and redistribute its layers.
|
|
789
|
+
|
|
790
|
+
This handles graceful degradation when nodes leave.
|
|
791
|
+
"""
|
|
792
|
+
with self.lock:
|
|
793
|
+
# Remove from capacities
|
|
794
|
+
self.node_capacities.pop(node_id, None)
|
|
795
|
+
|
|
796
|
+
# Find all layers this node was holding
|
|
797
|
+
orphaned_layers = []
|
|
798
|
+
|
|
799
|
+
for layer_id, assignments in self.layer_assignments.items():
|
|
800
|
+
# Remove this node's assignment
|
|
801
|
+
self.layer_assignments[layer_id] = [
|
|
802
|
+
a for a in assignments if a.node_id != node_id
|
|
803
|
+
]
|
|
804
|
+
|
|
805
|
+
# Check if layer is now orphaned (< MIN_REPLICAS)
|
|
806
|
+
if len(self.layer_assignments[layer_id]) < self.MIN_REPLICAS:
|
|
807
|
+
orphaned_layers.append(layer_id)
|
|
808
|
+
|
|
809
|
+
# Handle embedding/head holder leaving
|
|
810
|
+
if self.embedding_holder == node_id:
|
|
811
|
+
self.embedding_holder = None
|
|
812
|
+
if self.lm_head_holder == node_id:
|
|
813
|
+
self.lm_head_holder = None
|
|
814
|
+
|
|
815
|
+
if orphaned_layers:
|
|
816
|
+
logger.warning(f"Node {node_id[:8]}... left, {len(orphaned_layers)} layers need redistribution")
|
|
817
|
+
# In production, we would trigger redistribution here
|
|
818
|
+
|
|
819
|
+
def get_layer_holders(self, layer_id: int) -> List[LayerAssignment]:
|
|
820
|
+
"""Get all nodes holding a specific layer."""
|
|
821
|
+
with self.lock:
|
|
822
|
+
return list(self.layer_assignments.get(layer_id, []))
|
|
823
|
+
|
|
824
|
+
def get_pipeline_route(self) -> List[Tuple[int, str]]:
|
|
825
|
+
"""
|
|
826
|
+
Get the route for pipeline inference.
|
|
827
|
+
|
|
828
|
+
Returns list of (layer_id, grpc_addr) for each layer in order.
|
|
829
|
+
|
|
830
|
+
Filters out dead/stale nodes based on heartbeat timeout.
|
|
831
|
+
"""
|
|
832
|
+
with self.lock:
|
|
833
|
+
route = []
|
|
834
|
+
now = time.time()
|
|
835
|
+
|
|
836
|
+
for layer_id in range(self.current_num_layers):
|
|
837
|
+
holders = self.layer_assignments.get(layer_id, [])
|
|
838
|
+
if not holders:
|
|
839
|
+
logger.error(f"Layer {layer_id} has no holders!")
|
|
840
|
+
continue
|
|
841
|
+
|
|
842
|
+
# ROBUSTNESS: Filter out stale holders (expired heartbeat)
|
|
843
|
+
active_holders = [
|
|
844
|
+
h for h in holders
|
|
845
|
+
if (now - h.last_heartbeat) < self.HEARTBEAT_TIMEOUT
|
|
846
|
+
]
|
|
847
|
+
|
|
848
|
+
if not active_holders:
|
|
849
|
+
logger.warning(f"Layer {layer_id} has no ACTIVE holders "
|
|
850
|
+
f"({len(holders)} total, all stale)")
|
|
851
|
+
continue
|
|
852
|
+
|
|
853
|
+
# Pick best active holder (most recent heartbeat)
|
|
854
|
+
active_holders.sort(key=lambda h: -h.last_heartbeat)
|
|
855
|
+
route.append((layer_id, active_holders[0].grpc_addr))
|
|
856
|
+
|
|
857
|
+
logger.debug(f"Layer {layer_id}: selected {active_holders[0].node_id[:16]}... "
|
|
858
|
+
f"(heartbeat {now - active_holders[0].last_heartbeat:.1f}s ago)")
|
|
859
|
+
|
|
860
|
+
return route
|
|
861
|
+
|
|
862
|
+
def get_network_capacity(self) -> NetworkCapacity:
|
|
863
|
+
"""Get current network capacity with dynamic architecture."""
|
|
864
|
+
with self.lock:
|
|
865
|
+
total_memory = sum(self.node_capacities.values())
|
|
866
|
+
|
|
867
|
+
# Calculate max layers based on current architecture
|
|
868
|
+
if self.current_architecture:
|
|
869
|
+
memory_per_layer = estimate_memory_per_layer(self.current_architecture)
|
|
870
|
+
max_layers = int(total_memory * 0.6 / memory_per_layer)
|
|
871
|
+
else:
|
|
872
|
+
max_layers = 0
|
|
873
|
+
|
|
874
|
+
layer_coverage = {
|
|
875
|
+
layer_id: len(assignments)
|
|
876
|
+
for layer_id, assignments in self.layer_assignments.items()
|
|
877
|
+
}
|
|
878
|
+
|
|
879
|
+
return NetworkCapacity(
|
|
880
|
+
total_nodes=len(self.node_capacities),
|
|
881
|
+
total_memory_mb=total_memory,
|
|
882
|
+
max_layers=max_layers,
|
|
883
|
+
assigned_layers=self.current_num_layers,
|
|
884
|
+
layer_coverage=layer_coverage,
|
|
885
|
+
)
|
|
886
|
+
|
|
887
|
+
def heartbeat(self, node_id: str, layer_ids: List[int]):
|
|
888
|
+
"""Update heartbeat for a node's layers."""
|
|
889
|
+
with self.lock:
|
|
890
|
+
now = time.time()
|
|
891
|
+
for layer_id in layer_ids:
|
|
892
|
+
for assignment in self.layer_assignments.get(layer_id, []):
|
|
893
|
+
if assignment.node_id == node_id:
|
|
894
|
+
assignment.last_heartbeat = now
|
|
895
|
+
|
|
896
|
+
def cleanup_stale_assignments(self) -> int:
|
|
897
|
+
"""
|
|
898
|
+
Remove stale layer assignments (nodes that haven't heartbeat recently).
|
|
899
|
+
|
|
900
|
+
Returns number of stale assignments removed.
|
|
901
|
+
|
|
902
|
+
Called periodically to prevent dead peers from being selected for pipeline routing.
|
|
903
|
+
"""
|
|
904
|
+
with self.lock:
|
|
905
|
+
now = time.time()
|
|
906
|
+
removed_count = 0
|
|
907
|
+
|
|
908
|
+
for layer_id, assignments in list(self.layer_assignments.items()):
|
|
909
|
+
# Filter out stale assignments
|
|
910
|
+
active_assignments = [
|
|
911
|
+
a for a in assignments
|
|
912
|
+
if (now - a.last_heartbeat) < self.HEARTBEAT_TIMEOUT
|
|
913
|
+
]
|
|
914
|
+
|
|
915
|
+
stale_count = len(assignments) - len(active_assignments)
|
|
916
|
+
if stale_count > 0:
|
|
917
|
+
logger.info(f"Layer {layer_id}: removed {stale_count} stale assignments "
|
|
918
|
+
f"({len(active_assignments)} remain)")
|
|
919
|
+
removed_count += stale_count
|
|
920
|
+
|
|
921
|
+
# Update assignments
|
|
922
|
+
if active_assignments:
|
|
923
|
+
self.layer_assignments[layer_id] = active_assignments
|
|
924
|
+
else:
|
|
925
|
+
# No active holders for this layer!
|
|
926
|
+
logger.warning(f"Layer {layer_id}: NO active holders remaining!")
|
|
927
|
+
del self.layer_assignments[layer_id]
|
|
928
|
+
|
|
929
|
+
return removed_count
|
|
930
|
+
|
|
931
|
+
|
|
932
|
+
class DynamicNeuroLLM:
|
|
933
|
+
"""
|
|
934
|
+
A NeuroLLM that dynamically scales with the network.
|
|
935
|
+
|
|
936
|
+
Key differences from fixed-phase model:
|
|
937
|
+
- Number of layers AND hidden dimension determined by network capacity
|
|
938
|
+
- Layers are distributed across nodes
|
|
939
|
+
- Model grows organically in BOTH width and depth
|
|
940
|
+
- Architecture adapts automatically as network expands
|
|
941
|
+
"""
|
|
942
|
+
|
|
943
|
+
def __init__(
|
|
944
|
+
self,
|
|
945
|
+
node_id: str,
|
|
946
|
+
layer_pool: DynamicLayerPool,
|
|
947
|
+
device: str = "cpu"
|
|
948
|
+
):
|
|
949
|
+
self.node_id = node_id
|
|
950
|
+
self.layer_pool = layer_pool
|
|
951
|
+
self.device = device
|
|
952
|
+
|
|
953
|
+
# Get current architecture from layer pool
|
|
954
|
+
if layer_pool.current_architecture is None:
|
|
955
|
+
raise RuntimeError("Layer pool has no architecture - call _auto_recalculate_architecture first")
|
|
956
|
+
self.architecture = layer_pool.current_architecture
|
|
957
|
+
|
|
958
|
+
# My assigned layers
|
|
959
|
+
self.my_layers: Dict[int, torch.nn.Module] = {}
|
|
960
|
+
self.my_layer_ids: List[int] = []
|
|
961
|
+
|
|
962
|
+
# Callback for when layers change (set by DynamicNeuroNode to sync state)
|
|
963
|
+
self._on_layers_changed: Optional[callable] = None
|
|
964
|
+
|
|
965
|
+
# Reference to P2P manager for DHT updates during layer removal
|
|
966
|
+
self._p2p_manager = None
|
|
967
|
+
|
|
968
|
+
# Do I hold embedding/head?
|
|
969
|
+
self.has_embedding = False
|
|
970
|
+
self.has_lm_head = False
|
|
971
|
+
|
|
972
|
+
# Shared components (if I hold them)
|
|
973
|
+
self.embedding: Optional[torch.nn.Embedding] = None
|
|
974
|
+
self.lm_head: Optional[torch.nn.Linear] = None
|
|
975
|
+
self.final_norm: Optional[torch.nn.Module] = None
|
|
976
|
+
|
|
977
|
+
# Training mode flag (PyTorch convention)
|
|
978
|
+
self.training = False
|
|
979
|
+
|
|
980
|
+
logger.info(f"DynamicNeuroLLM initialized for node {node_id[:8]}... "
|
|
981
|
+
f"with {self.architecture.num_layers}L × {self.architecture.hidden_dim}H architecture")
|
|
982
|
+
|
|
983
|
+
def initialize_layers(self, layer_ids: List[int]):
|
|
984
|
+
"""Initialize the layers assigned to this node using DYNAMIC architecture."""
|
|
985
|
+
from neuroshard.core.model.llm import NeuroLLMConfig, NeuroDecoderLayer
|
|
986
|
+
|
|
987
|
+
# Create config from current architecture (DYNAMIC!)
|
|
988
|
+
config = NeuroLLMConfig(
|
|
989
|
+
hidden_dim=self.architecture.hidden_dim,
|
|
990
|
+
intermediate_dim=self.architecture.intermediate_dim,
|
|
991
|
+
num_layers=self.architecture.num_layers,
|
|
992
|
+
num_heads=self.architecture.num_heads,
|
|
993
|
+
num_kv_heads=self.architecture.num_kv_heads,
|
|
994
|
+
vocab_size=self.architecture.vocab_size,
|
|
995
|
+
max_seq_len=self.architecture.max_seq_len,
|
|
996
|
+
dropout=self.architecture.dropout,
|
|
997
|
+
rope_theta=self.architecture.rope_theta,
|
|
998
|
+
)
|
|
999
|
+
|
|
1000
|
+
for layer_id in layer_ids:
|
|
1001
|
+
layer = NeuroDecoderLayer(config, layer_id)
|
|
1002
|
+
layer.to(self.device)
|
|
1003
|
+
self.my_layers[layer_id] = layer
|
|
1004
|
+
|
|
1005
|
+
self.my_layer_ids = sorted(layer_ids)
|
|
1006
|
+
|
|
1007
|
+
# Initialize embedding if I'm the holder (uses dynamic hidden_dim!)
|
|
1008
|
+
# Vocab capacity can grow dynamically as tokenizer learns more merges
|
|
1009
|
+
self.vocab_capacity = INITIAL_VOCAB_SIZE
|
|
1010
|
+
if self.layer_pool.embedding_holder == self.node_id:
|
|
1011
|
+
self.embedding = torch.nn.Embedding(self.vocab_capacity, self.architecture.hidden_dim)
|
|
1012
|
+
self.embedding.to(self.device)
|
|
1013
|
+
self.has_embedding = True
|
|
1014
|
+
|
|
1015
|
+
# Initialize LM head if I'm the holder (uses dynamic hidden_dim!)
|
|
1016
|
+
if self.layer_pool.lm_head_holder == self.node_id:
|
|
1017
|
+
self.lm_head = torch.nn.Linear(self.architecture.hidden_dim, self.vocab_capacity, bias=False)
|
|
1018
|
+
from neuroshard.core.model.llm import RMSNorm
|
|
1019
|
+
self.final_norm = RMSNorm(self.architecture.hidden_dim)
|
|
1020
|
+
self.lm_head.to(self.device)
|
|
1021
|
+
self.final_norm.to(self.device)
|
|
1022
|
+
self.has_lm_head = True
|
|
1023
|
+
|
|
1024
|
+
logger.info(f"Initialized {len(layer_ids)} layers: {layer_ids}, "
|
|
1025
|
+
f"arch={self.architecture.num_layers}L×{self.architecture.hidden_dim}H, "
|
|
1026
|
+
f"embedding={self.has_embedding}, head={self.has_lm_head}")
|
|
1027
|
+
|
|
1028
|
+
def initialize_lm_head(self) -> bool:
|
|
1029
|
+
"""
|
|
1030
|
+
Dynamically initialize the LM head (for validator upgrade).
|
|
1031
|
+
|
|
1032
|
+
Called when a node is upgraded to Validator after staking.
|
|
1033
|
+
No restart required - initializes the head in place.
|
|
1034
|
+
|
|
1035
|
+
Returns True if initialization was successful.
|
|
1036
|
+
"""
|
|
1037
|
+
if self.has_lm_head:
|
|
1038
|
+
logger.info("LM head already initialized")
|
|
1039
|
+
return True
|
|
1040
|
+
|
|
1041
|
+
try:
|
|
1042
|
+
from neuroshard.core.model.llm import RMSNorm
|
|
1043
|
+
|
|
1044
|
+
self.lm_head = torch.nn.Linear(self.architecture.hidden_dim, self.vocab_capacity, bias=False)
|
|
1045
|
+
self.final_norm = RMSNorm(self.architecture.hidden_dim)
|
|
1046
|
+
self.lm_head.to(self.device)
|
|
1047
|
+
self.final_norm.to(self.device)
|
|
1048
|
+
self.has_lm_head = True
|
|
1049
|
+
|
|
1050
|
+
# Add last layer to my layers if not already there
|
|
1051
|
+
last_layer = self.architecture.num_layers - 1
|
|
1052
|
+
if last_layer not in self.my_layer_ids:
|
|
1053
|
+
self.my_layer_ids.append(last_layer)
|
|
1054
|
+
self.my_layer_ids = sorted(self.my_layer_ids)
|
|
1055
|
+
|
|
1056
|
+
logger.info(f"LM head initialized! Now computing REAL cross-entropy loss")
|
|
1057
|
+
return True
|
|
1058
|
+
except Exception as e:
|
|
1059
|
+
logger.error(f"Failed to initialize LM head: {e}")
|
|
1060
|
+
|
|
1061
|
+
def disable_lm_head(self) -> bool:
|
|
1062
|
+
"""
|
|
1063
|
+
Disable the LM head (for validator demotion).
|
|
1064
|
+
|
|
1065
|
+
Called when a validator is demoted due to insufficient stake.
|
|
1066
|
+
The node reverts to Worker role and uses activation norm as loss.
|
|
1067
|
+
|
|
1068
|
+
Returns True if demotion was successful.
|
|
1069
|
+
"""
|
|
1070
|
+
if not self.has_lm_head:
|
|
1071
|
+
logger.debug("LM head not initialized, nothing to disable")
|
|
1072
|
+
return False
|
|
1073
|
+
|
|
1074
|
+
try:
|
|
1075
|
+
# Free memory from LM head
|
|
1076
|
+
if self.lm_head is not None:
|
|
1077
|
+
del self.lm_head
|
|
1078
|
+
self.lm_head = None
|
|
1079
|
+
if self.final_norm is not None:
|
|
1080
|
+
del self.final_norm
|
|
1081
|
+
self.final_norm = None
|
|
1082
|
+
|
|
1083
|
+
self.has_lm_head = False
|
|
1084
|
+
|
|
1085
|
+
# Force garbage collection to free memory
|
|
1086
|
+
import gc
|
|
1087
|
+
gc.collect()
|
|
1088
|
+
if torch.cuda.is_available():
|
|
1089
|
+
torch.cuda.empty_cache()
|
|
1090
|
+
|
|
1091
|
+
logger.warning(f"LM head DISABLED - node demoted to Worker (will use activation norm as loss)")
|
|
1092
|
+
return True
|
|
1093
|
+
except Exception as e:
|
|
1094
|
+
logger.error(f"Failed to disable LM head: {e}")
|
|
1095
|
+
return False

    def add_layers(self, new_layer_ids: List[int]) -> List[int]:
        """
        Dynamically add new layers to a running model.

        This allows the model to grow during training without restart!
        Called when:
        - Network needs more layers
        - Node has available memory
        - Vocab expansion freed up memory

        Args:
            new_layer_ids: Layer IDs to add

        Returns:
            List of layer IDs that were successfully added
        """
        from neuroshard.core.model.llm import NeuroLLMConfig, NeuroDecoderLayer

        # Filter out layers we already have
        layers_to_add = [lid for lid in new_layer_ids if lid not in self.my_layers]
        if not layers_to_add:
            return []

        # Check memory before adding
        memory_per_layer = estimate_memory_per_layer(self.architecture)
        required_mb = memory_per_layer * len(layers_to_add)

        try:
            if torch.cuda.is_available() and self.device != "cpu":
                free_mb = (torch.cuda.get_device_properties(0).total_memory -
                           torch.cuda.memory_allocated()) / (1024 * 1024)
            else:
                import psutil
                free_mb = psutil.virtual_memory().available / (1024 * 1024)

            if free_mb < required_mb * 1.5:  # 1.5x safety margin
                logger.warning(f"[LAYER] Insufficient memory to add {len(layers_to_add)} layers: "
                               f"need {required_mb:.0f}MB, have {free_mb:.0f}MB")
                # Add as many as we can fit
                can_fit = int(free_mb / (memory_per_layer * 1.5))
                layers_to_add = layers_to_add[:can_fit]
                if not layers_to_add:
                    return []
        except Exception as e:
            logger.warning(f"[LAYER] Memory check failed: {e}, proceeding cautiously")

        # Create config from architecture
        config = NeuroLLMConfig(
            hidden_dim=self.architecture.hidden_dim,
            intermediate_dim=self.architecture.intermediate_dim,
            num_layers=self.architecture.num_layers,
            num_heads=self.architecture.num_heads,
            num_kv_heads=self.architecture.num_kv_heads,
            vocab_size=self.architecture.vocab_size,
            max_seq_len=self.architecture.max_seq_len,
            dropout=self.architecture.dropout,
            rope_theta=self.architecture.rope_theta,
        )

        added = []
        for layer_id in layers_to_add:
            try:
                layer = NeuroDecoderLayer(config, layer_id)
                layer.to(self.device)
                self.my_layers[layer_id] = layer
                added.append(layer_id)
            except Exception as e:
                logger.error(f"[LAYER] Failed to add layer {layer_id}: {e}")
                break  # Stop on first failure

        if added:
            self.my_layer_ids = sorted(self.my_layers.keys())
            logger.info(f"[LAYER] ✅ Added {len(added)} layers: {added}")
            logger.info(f"[LAYER] Now holding {len(self.my_layer_ids)} layers: {self.my_layer_ids[:5]}...{self.my_layer_ids[-5:]}")

        return added
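
    # Usage sketch (illustrative only, not part of the original module): a node
    # that notices spare memory can try to grow by one layer. `node`, `pool`,
    # `free_memory_mb()` and `get_unassigned_layer_ids()` are hypothetical names
    # used only for this example; `estimate_memory_per_layer` is the helper
    # already used above.
    #
    #     unassigned = pool.get_unassigned_layer_ids()
    #     if unassigned and free_memory_mb() > 2 * estimate_memory_per_layer(model.architecture):
    #         added = model.add_layers(unassigned[:1])
    #         if added:
    #             node.recalculate_training_config()   # batch size / checkpointing may change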

    def remove_layers(self, layer_ids_to_remove: List[int], layer_pool=None, p2p_manager=None) -> List[int]:
        """
        Dynamically remove layers from a running model.

        This allows the model to shrink during training without restart!
        Called when:
        - Vocab growth needs more memory
        - Network is redistributing layers
        - Memory pressure detected

        IMPORTANT: This also updates layer_pool and DHT announcements so other
        nodes know we no longer hold these layers!

        Args:
            layer_ids_to_remove: Layer IDs to remove
            layer_pool: DynamicLayerPool to update assignments (optional)
            p2p_manager: P2PManager to update DHT announcements (optional)

        Returns:
            List of layer IDs that were successfully removed
        """
        # Can't remove layers we don't have
        layers_to_remove = [lid for lid in layer_ids_to_remove if lid in self.my_layers]
        if not layers_to_remove:
            return []

        # Don't remove all layers - keep at least 1
        if len(layers_to_remove) >= len(self.my_layers):
            layers_to_remove = layers_to_remove[:-1]  # Keep last one
            if not layers_to_remove:
                logger.warning("[LAYER] Cannot remove all layers")
                return []

        removed = []
        for layer_id in layers_to_remove:
            try:
                # Delete the layer
                layer = self.my_layers.pop(layer_id)
                del layer
                removed.append(layer_id)
            except Exception as e:
                logger.error(f"[LAYER] Failed to remove layer {layer_id}: {e}")

        if removed:
            self.my_layer_ids = sorted(self.my_layers.keys())

            # UPDATE LAYER POOL: Remove this node from removed layers' assignments
            if layer_pool:
                try:
                    with layer_pool.lock:
                        for layer_id in removed:
                            if layer_id in layer_pool.layer_assignments:
                                # Remove our node from this layer's holders
                                layer_pool.layer_assignments[layer_id] = [
                                    a for a in layer_pool.layer_assignments[layer_id]
                                    if a.node_id != self.node_id
                                ]
                                logger.info(f"[LAYER] Updated layer_pool: removed self from layer {layer_id}")
                except Exception as e:
                    logger.warning(f"[LAYER] Could not update layer_pool: {e}")

            # UPDATE DHT: Change announced layer range
            if p2p_manager and self.my_layer_ids:
                try:
                    new_start = min(self.my_layer_ids)
                    new_end = max(self.my_layer_ids)
                    p2p_manager.start_layer = new_start
                    p2p_manager.end_layer = new_end
                    p2p_manager.shard_range = f"{new_start}-{new_end}"
                    logger.info(f"[LAYER] Updated P2P shard_range: {new_start}-{new_end}")

                    # Re-announce immediately so network knows
                    if hasattr(p2p_manager, '_announce_once'):
                        p2p_manager._announce_once(verbose=True)
                except Exception as e:
                    logger.warning(f"[LAYER] Could not update P2P shard_range: {e}")

            # Force garbage collection to free memory
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"[LAYER] ✅ Removed {len(removed)} layers: {removed}")
            logger.info(f"[LAYER] Now holding {len(self.my_layer_ids)} layers")

            # NOTIFY: Call callback so node can sync its state
            if self._on_layers_changed:
                try:
                    self._on_layers_changed(self.my_layer_ids)
                except Exception as e:
                    logger.warning(f"[LAYER] Callback failed: {e}")

        return removed
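
    # Illustrative sketch (not part of the original module): a caller reacting to
    # memory pressure might shed its highest layers and let the layer_pool update
    # plus DHT re-announcement propagate the change. `node` and
    # `detect_memory_pressure()` are hypothetical names used only here.
    #
    #     if detect_memory_pressure():
    #         victims = sorted(node.model.my_layers.keys(), reverse=True)[:2]
    #         node.model.remove_layers(victims,
    #                                  layer_pool=node.layer_pool,
    #                                  p2p_manager=node.p2p_manager)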

    def expand_vocabulary(self, new_vocab_size: int) -> bool:
        """
        Expand embedding and lm_head to accommodate a larger vocabulary.

        This is called when the tokenizer learns new BPE merges that exceed
        the current vocabulary capacity. The expansion preserves existing
        token embeddings while initializing new ones.

        For an ever-growing decentralized LLM, vocabulary expansion is essential
        as millions of users contribute diverse training data across languages
        and domains.

        Args:
            new_vocab_size: The new vocabulary size (must be > current capacity)

        Returns:
            True if expansion was successful, False otherwise
        """
        if new_vocab_size <= self.vocab_capacity:
            return True  # No expansion needed

        # Check against max (if set)
        if MAX_VOCAB_SIZE is not None and new_vocab_size > MAX_VOCAB_SIZE:
            logger.warning(f"Requested vocab {new_vocab_size} exceeds MAX_VOCAB_SIZE {MAX_VOCAB_SIZE}")
            new_vocab_size = MAX_VOCAB_SIZE

        # Round up to next VOCAB_GROWTH_CHUNK for efficient memory alignment
        new_capacity = ((new_vocab_size + VOCAB_GROWTH_CHUNK - 1) // VOCAB_GROWTH_CHUNK) * VOCAB_GROWTH_CHUNK
        if MAX_VOCAB_SIZE is not None:
            new_capacity = min(new_capacity, MAX_VOCAB_SIZE)

        # MEMORY CHECK: Estimate memory needed for expansion
        # Embedding + LM head expansion: 2 * (new - old) * hidden_dim * 4 bytes (float32)
        hidden_dim = self.architecture.hidden_dim
        expansion_params = 2 * (new_capacity - self.vocab_capacity) * hidden_dim
        expansion_memory_mb = (expansion_params * 4) / (1024 * 1024)  # Just weights, no optimizer yet

        # Check available memory (GPU or CPU)
        try:
            if torch.cuda.is_available() and self.device != "cpu":
                # Check GPU memory
                free_memory_mb = (torch.cuda.get_device_properties(0).total_memory -
                                  torch.cuda.memory_allocated()) / (1024 * 1024)
            else:
                # Check system RAM
                import psutil
                free_memory_mb = psutil.virtual_memory().available / (1024 * 1024)

            # Need at least 2x expansion memory (weights + temporary copy during expansion)
            required_mb = expansion_memory_mb * 2

            if free_memory_mb < required_mb:
                logger.warning(f"[VOCAB] Insufficient memory for expansion: need {required_mb:.0f}MB, "
                               f"have {free_memory_mb:.0f}MB free")

                # TRY: Remove some layers to make room for vocab expansion
                # Vocab is more important than extra layers (all nodes need same vocab)
                memory_per_layer = estimate_memory_per_layer(self.architecture)
                layers_to_free = int((required_mb - free_memory_mb) / memory_per_layer) + 1

                if len(self.my_layers) > layers_to_free + 1:  # Keep at least 1 layer
                    # Remove highest-numbered layers (least important for Driver/Validator)
                    layers_to_remove = sorted(self.my_layers.keys(), reverse=True)[:layers_to_free]
                    logger.warning(f"[VOCAB] Attempting to free memory by removing {layers_to_free} layers: {layers_to_remove}")

                    # Pass layer_pool so network state is updated
                    # p2p_manager will be notified via layer_pool.on_layers_changed callback
                    removed = self.remove_layers(
                        layers_to_remove,
                        layer_pool=self.layer_pool,
                        p2p_manager=getattr(self, '_p2p_manager', None)
                    )
                    if removed:
                        # Recalculate free memory after layer removal
                        if torch.cuda.is_available() and self.device != "cpu":
                            free_memory_mb = (torch.cuda.get_device_properties(0).total_memory -
                                              torch.cuda.memory_allocated()) / (1024 * 1024)
                        else:
                            import psutil
                            free_memory_mb = psutil.virtual_memory().available / (1024 * 1024)

                        if free_memory_mb >= required_mb:
                            logger.info(f"[VOCAB] ✅ Freed enough memory by removing {len(removed)} layers")
                            # Continue with expansion below
                        else:
                            logger.warning("[VOCAB] Still insufficient memory after layer removal")
                            return False
                    else:
                        logger.warning("[VOCAB] Could not remove layers, aborting expansion")
                        return False
                else:
                    logger.warning("[VOCAB] Not enough layers to remove, aborting expansion")
                    return False

            logger.info(f"[VOCAB] Memory check passed: {expansion_memory_mb:.0f}MB needed, "
                        f"{free_memory_mb:.0f}MB available")
        except Exception as e:
            logger.warning(f"[VOCAB] Could not check memory: {e}, proceeding with expansion")

        logger.info(f"[VOCAB] Expanding vocabulary: {self.vocab_capacity} → {new_capacity}")

        try:
            # Expand embedding if we have it
            if self.has_embedding and self.embedding is not None:
                old_embedding = self.embedding
                old_vocab = old_embedding.weight.shape[0]
                hidden_dim = old_embedding.weight.shape[1]

                # Create new larger embedding
                new_embedding = torch.nn.Embedding(new_capacity, hidden_dim)
                new_embedding.to(self.device)

                # Copy existing embeddings
                with torch.no_grad():
                    new_embedding.weight[:old_vocab] = old_embedding.weight
                    # Initialize new embeddings with small random values
                    # (similar to how transformers initialize)
                    std = 0.02
                    new_embedding.weight[old_vocab:].normal_(mean=0.0, std=std)

                # Replace old embedding
                del self.embedding
                self.embedding = new_embedding

                logger.info(f"[VOCAB] Expanded embedding: {old_vocab} → {new_capacity} tokens")

            # Expand lm_head if we have it
            if self.has_lm_head and self.lm_head is not None:
                old_lm_head = self.lm_head
                old_vocab = old_lm_head.weight.shape[0]
                hidden_dim = old_lm_head.weight.shape[1]

                # Create new larger lm_head
                new_lm_head = torch.nn.Linear(hidden_dim, new_capacity, bias=False)
                new_lm_head.to(self.device)

                # Copy existing weights
                with torch.no_grad():
                    new_lm_head.weight[:old_vocab] = old_lm_head.weight
                    # Initialize new output weights
                    std = 0.02
                    new_lm_head.weight[old_vocab:].normal_(mean=0.0, std=std)

                # Replace old lm_head
                del self.lm_head
                self.lm_head = new_lm_head

                logger.info(f"[VOCAB] Expanded lm_head: {old_vocab} → {new_capacity} output classes")

            # Update capacity
            old_capacity = self.vocab_capacity
            self.vocab_capacity = new_capacity

            # Force garbage collection
            import gc
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()

            logger.info(f"[VOCAB] ✅ Vocabulary expansion complete: {old_capacity} → {new_capacity}")
            return True

        except Exception as e:
            logger.error(f"[VOCAB] Failed to expand vocabulary: {e}")
            return False
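
    # Worked example (illustrative; VOCAB_GROWTH_CHUNK and hidden_dim values are
    # assumed here for arithmetic, not taken from this file): with
    # VOCAB_GROWTH_CHUNK = 4096, hidden_dim = 1024, current vocab_capacity = 65536
    # and a requested new_vocab_size = 70000:
    #
    #     new_capacity      = ceil(70000 / 4096) * 4096   # = 18 * 4096 = 73728
    #     expansion_params  = 2 * (73728 - 65536) * 1024  # = 16,777,216 params
    #     expansion_memory  = 16,777,216 * 4 bytes        # ≈ 64 MB of new weights
    #     required_mb       = 2 * 64 MB = 128 MB          # 2x for the temporary copy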

    def check_and_expand_vocab_if_needed(self) -> bool:
        """
        Check if tokenizer vocabulary exceeds current capacity and expand if needed.

        This should be called periodically (e.g., after loading new tokenizer from CDN)
        to ensure the model can handle all tokens in the current vocabulary.

        Returns:
            True if no expansion needed or expansion successful, False on failure
        """
        if self.tokenizer is None:
            return True

        current_vocab = self.tokenizer.current_vocab_size
        if current_vocab > self.vocab_capacity:
            logger.info(f"[VOCAB] Tokenizer vocab ({current_vocab}) exceeds capacity ({self.vocab_capacity})")
            return self.expand_vocabulary(current_vocab)

        return True

    def forward_my_layers(
        self,
        hidden_states: torch.Tensor,
        start_layer: Optional[int] = None,
        end_layer: Optional[int] = None,
    ) -> torch.Tensor:
        """Forward through my assigned layers."""
        if start_layer is None:
            start_layer = min(self.my_layer_ids) if self.my_layer_ids else 0
        if end_layer is None:
            end_layer = max(self.my_layer_ids) + 1 if self.my_layer_ids else 0

        x = hidden_states

        for layer_id in range(start_layer, end_layer):
            if layer_id in self.my_layers:
                x, _ = self.my_layers[layer_id](x)

        return x

    def embed(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Embed input tokens (only if I hold embedding)."""
        if not self.has_embedding:
            raise RuntimeError("This node does not hold the embedding layer")
        return self.embedding(input_ids)

    def compute_logits(self, hidden_states: torch.Tensor) -> torch.Tensor:
        """Compute logits (only if I hold LM head)."""
        if not self.has_lm_head:
            raise RuntimeError("This node does not hold the LM head")
        x = self.final_norm(hidden_states)
        return self.lm_head(x)

    def get_num_params(self) -> int:
        """Get number of parameters on this node."""
        total = 0
        for layer in self.my_layers.values():
            total += sum(p.numel() for p in layer.parameters())
        if self.embedding:
            total += sum(p.numel() for p in self.embedding.parameters())
        if self.lm_head:
            total += sum(p.numel() for p in self.lm_head.parameters())
        if self.final_norm:
            total += sum(p.numel() for p in self.final_norm.parameters())
        return total

    def parameters(self):
        """Yield all parameters for this node's model components (for optimizer/gradient clipping)."""
        for layer in self.my_layers.values():
            yield from layer.parameters()
        if self.embedding:
            yield from self.embedding.parameters()
        if self.lm_head:
            yield from self.lm_head.parameters()
        if self.final_norm:
            yield from self.final_norm.parameters()

    def named_parameters(self, prefix: str = '', recurse: bool = True):
        """
        Yield (name, param) tuples for all parameters.

        This is the standard PyTorch interface for iterating over named parameters.
        """
        # Layers
        for layer_id, layer in self.my_layers.items():
            layer_prefix = f"{prefix}layers.{layer_id}." if prefix else f"layers.{layer_id}."
            for name, param in layer.named_parameters(prefix='', recurse=recurse):
                yield layer_prefix + name, param

        # Embedding
        if self.embedding:
            emb_prefix = f"{prefix}embedding." if prefix else "embedding."
            for name, param in self.embedding.named_parameters(prefix='', recurse=recurse):
                yield emb_prefix + name, param

        # LM Head
        if self.lm_head:
            head_prefix = f"{prefix}lm_head." if prefix else "lm_head."
            for name, param in self.lm_head.named_parameters(prefix='', recurse=recurse):
                yield head_prefix + name, param

        # Final Norm
        if self.final_norm:
            norm_prefix = f"{prefix}final_norm." if prefix else "final_norm."
            for name, param in self.final_norm.named_parameters(prefix='', recurse=recurse):
                yield norm_prefix + name, param

    def state_dict(self) -> Dict[str, Any]:
        """
        Return the state dictionary of the model.

        This is the standard PyTorch interface for saving model state.
        """
        state = {}

        # Layers
        for layer_id, layer in self.my_layers.items():
            for name, param in layer.state_dict().items():
                state[f"layers.{layer_id}.{name}"] = param

        # Embedding
        if self.embedding:
            for name, param in self.embedding.state_dict().items():
                state[f"embedding.{name}"] = param

        # LM Head
        if self.lm_head:
            for name, param in self.lm_head.state_dict().items():
                state[f"lm_head.{name}"] = param

        # Final Norm
        if self.final_norm:
            for name, param in self.final_norm.state_dict().items():
                state[f"final_norm.{name}"] = param

        return state

    def load_state_dict(self, state_dict: Dict[str, Any], strict: bool = True):
        """
        Load state dictionary into the model.

        This is the standard PyTorch interface for loading model state.
        """
        # Group state by component
        layer_states: Dict[int, Dict[str, Any]] = {}
        embedding_state: Dict[str, Any] = {}
        lm_head_state: Dict[str, Any] = {}
        final_norm_state: Dict[str, Any] = {}

        for key, value in state_dict.items():
            if key.startswith("layers."):
                parts = key.split(".", 2)
                layer_id = int(parts[1])
                param_name = parts[2]
                if layer_id not in layer_states:
                    layer_states[layer_id] = {}
                layer_states[layer_id][param_name] = value
            elif key.startswith("embedding."):
                param_name = key[len("embedding."):]
                embedding_state[param_name] = value
            elif key.startswith("lm_head."):
                param_name = key[len("lm_head."):]
                lm_head_state[param_name] = value
            elif key.startswith("final_norm."):
                param_name = key[len("final_norm."):]
                final_norm_state[param_name] = value

        # Load into components
        for layer_id, layer in self.my_layers.items():
            if layer_id in layer_states:
                layer.load_state_dict(layer_states[layer_id], strict=strict)

        if self.embedding and embedding_state:
            self.embedding.load_state_dict(embedding_state, strict=strict)

        if self.lm_head and lm_head_state:
            self.lm_head.load_state_dict(lm_head_state, strict=strict)

        if self.final_norm and final_norm_state:
            self.final_norm.load_state_dict(final_norm_state, strict=strict)
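
    # Illustrative round-trip (not part of the original module): because keys are
    # namespaced as "layers.<id>.<param>", "embedding.<param>", "lm_head.<param>"
    # and "final_norm.<param>", a shard-local checkpoint can be saved and restored
    # with the usual torch helpers; `path` is a hypothetical location.
    #
    #     torch.save(model.state_dict(), path)
    #     model.load_state_dict(torch.load(path, map_location="cpu"), strict=False)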

    def zero_grad(self, set_to_none: bool = False):
        """
        Zero all gradients.

        This is the standard PyTorch interface for zeroing gradients.
        """
        for param in self.parameters():
            if param.grad is not None:
                if set_to_none:
                    param.grad = None
                else:
                    param.grad.zero_()

    def train(self, mode: bool = True) -> 'DynamicNeuroLLM':
        """
        Set the model to training mode.

        This is the standard PyTorch interface for setting training mode
        on all submodules.
        """
        self.training = mode
        for layer in self.my_layers.values():
            layer.train(mode)
        if self.embedding:
            self.embedding.train(mode)
        if self.lm_head:
            self.lm_head.train(mode)
        if self.final_norm:
            self.final_norm.train(mode)
        return self

    def eval(self) -> 'DynamicNeuroLLM':
        """Set the model to evaluation mode."""
        return self.train(False)

    def get_my_contribution(self) -> Dict[str, Any]:
        """Get this node's contribution to the network."""
        capacity = self.layer_pool.get_network_capacity()

        return {
            "node_id": self.node_id[:16] + "...",
            "my_layers": self.my_layer_ids,
            "my_params": self.get_num_params(),
            "has_embedding": self.has_embedding,
            "has_lm_head": self.has_lm_head,
            "network_total_layers": capacity.assigned_layers,
            "network_total_nodes": capacity.total_nodes,
            "contribution_ratio": len(self.my_layer_ids) / max(1, capacity.assigned_layers),
        }


def calculate_reward_multiplier(
    num_layers_held: int,
    total_network_layers: int,
    has_embedding: bool,
    has_lm_head: bool
) -> float:
    """
    Calculate NEURO reward multiplier based on contribution.

    Roles:
    - Worker: Standard reward based on layers
    - Driver (Embedding): 1.2x bonus (bandwidth cost)
    - Validator (Head): 1.2x bonus (compute/consensus cost)
    """
    if total_network_layers == 0:
        return 1.0

    # Base multiplier from layer contribution
    layer_ratio = num_layers_held / total_network_layers
    base_multiplier = 1.0 + layer_ratio  # 1.0 to 2.0 based on layers

    # Bonus for critical components (Roles)
    if has_embedding:
        base_multiplier *= 1.2  # 20% bonus for Driving (Data bandwidth)
    if has_lm_head:
        base_multiplier *= 1.2  # 20% bonus for Validating (Loss calc + Gradient origin)

    return base_multiplier
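
# Worked example (illustrative): a node holding 8 of the network's 32 layers that
# also holds the LM head gets
#
#     calculate_reward_multiplier(8, 32, has_embedding=False, has_lm_head=True)
#     = (1.0 + 8/32) * 1.2 = 1.25 * 1.2 = 1.5
#
# A node holding both the embedding and the LM head would get a further 1.2x on top.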


# ============================================================================
# DYNAMIC NEURO NODE - The Main Node Class
# ============================================================================

class DynamicNeuroNode:
    """
    A truly decentralized node that contributes based on available memory.

    NO PHASES. NO CENTRAL COORDINATION.

    How it works:
    1. Node starts, detects available memory
    2. Registers with network, gets assigned layers
    3. Loads only the layers it's responsible for
    4. Participates in training (computes gradients for its layers)
    5. Participates in inference (forwards through its layers)
    6. Earns NEURO proportional to its contribution

    The more memory you have, the more layers you hold, the more you earn.
    """

    CHECKPOINT_DIR = None  # Set in __init__

    def __init__(
        self,
        node_id: str,
        port: int = 8000,
        tracker_url: str = "https://neuroshard.com/api/tracker",
        node_token: Optional[str] = None,
        device: str = "cpu",
        available_memory_mb: Optional[float] = None,
        enable_training: bool = True,
        max_storage_mb: float = 100.0,
        max_cpu_threads: Optional[int] = None,
    ):
        self.node_id = node_id
        self.port = port
        self.tracker_url = tracker_url
        self.node_token = node_token

        # Detect device automatically if "auto" or "cpu" (backward compatibility)
        if device in ("auto", "cpu"):
            if torch.cuda.is_available():
                self.device = "cuda"
                logger.info(f"[NODE] GPU detected: CUDA available")
            elif torch.backends.mps.is_available():
                # MPS (Apple Silicon GPU) - enabled now that we have GIL yields
                self.device = "mps"
                logger.info(f"[NODE] GPU detected: Apple Metal (MPS)")
            else:
                self.device = "cpu"
                logger.info(f"[NODE] No GPU detected, using CPU")

                # Help debug why CUDA isn't available
                import subprocess
                import sys

                # Check if NVIDIA GPU exists
                has_nvidia_gpu = False
                try:
                    if sys.platform == 'win32':
                        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=2)
                        has_nvidia_gpu = result.returncode == 0
                    elif sys.platform == 'darwin':
                        # macOS doesn't have NVIDIA support (use MPS instead)
                        pass
                    else:  # Linux
                        result = subprocess.run(['nvidia-smi'], capture_output=True, text=True, timeout=2)
                        has_nvidia_gpu = result.returncode == 0
                except Exception:
                    pass

                # Detailed diagnostic
                logger.info(f"[NODE] torch.cuda.is_available() = False")

                # Check if PyTorch was built with CUDA
                try:
                    cuda_available = torch.cuda.is_available()
                    cuda_built = getattr(torch.version, 'cuda', None)
                    torch_version = torch.__version__
                    logger.info(f"[NODE] PyTorch version: {torch_version}")
                    logger.info(f"[NODE] CUDA compiled version: {cuda_built if cuda_built else 'None (CPU-only build)'}")
                except Exception as e:
                    logger.info(f"[NODE] Could not get CUDA info: {e}")

                # Provide helpful diagnostic
                if has_nvidia_gpu:
                    logger.warning("⚠️ NVIDIA GPU DETECTED BUT NOT BEING USED!")
                    logger.warning("Your system has an NVIDIA GPU, but this PyTorch installation is CPU-only.")
                    logger.warning("🔧 TO ENABLE GPU (for 5-10x faster training):")
                    logger.warning("If running the .exe (frozen build):")
                    logger.warning(" Unfortunately, the bundled Python environment can't easily be modified.")
                    logger.warning(" We recommend running from source for GPU support.")
                    logger.warning("If running from source:")
                    logger.warning(" pip uninstall torch")
                    logger.warning(" pip install torch --index-url https://download.pytorch.org/whl/cu121")
                    logger.warning("To verify: python -c \"import torch; print(torch.cuda.is_available())\"")
        else:
            self.device = device
            logger.info(f"[NODE] Device manually set to: {self.device}")

        logger.info(f"Using device: {self.device}")

        self.enable_training = enable_training
        self.max_storage_mb = max_storage_mb
        self.max_cpu_threads = max_cpu_threads

        # CPU thread limiting is done in runner.py BEFORE any torch operations
        # (torch.set_num_interop_threads must be called before any parallel work)
        if max_cpu_threads and self.device == "cpu":
            torch.set_num_threads(max_cpu_threads)  # Intra-op parallelism only
            logger.info(f"Set PyTorch intra-op threads: {max_cpu_threads}")

        # Detect memory if not provided
        if available_memory_mb is None:
            self.available_memory_mb = self._detect_available_memory()
        else:
            self.available_memory_mb = available_memory_mb

        # Checkpoint directory
        from pathlib import Path
        self.CHECKPOINT_DIR = Path.home() / ".neuroshard" / "checkpoints"
        self.CHECKPOINT_DIR.mkdir(parents=True, exist_ok=True)

        # Compute wallet_id from token for stable checkpoint naming
        # (node_id can change if machine_id changes, but wallet_id is stable)
        if node_token:
            self.wallet_id = hashlib.sha256(node_token.encode()).hexdigest()[:16]
        else:
            self.wallet_id = self.node_id[:16]  # Fallback to node_id

        # Layer pool (shared across network via DHT)
        self.layer_pool: Optional[DynamicLayerPool] = None

        # My model (only my layers)
        self.model: Optional[DynamicNeuroLLM] = None
        self.my_layer_ids: List[int] = []

        # Tokenizer
        self.tokenizer = None

        # Training components (enable_training set in __init__)
        self.optimizer: Optional[torch.optim.Optimizer] = None
        self.training_coordinator = None
        self.data_manager = None
        self.gradient_gossip = None

        # Training lock to prevent concurrent training operations
        # (local training vs pipeline training conflict)
        self._training_lock = threading.Lock()

        # P2P
        self.p2p_manager = None

        # Stats
        self.is_running = False
        self.total_tokens_processed = 0
        self.total_training_rounds = 0
        self.current_loss = float('inf')
        self.inference_count = 0
        self.training_contribution_count = 0

        # KV cache for inference
        self.kv_cache: Dict[str, Any] = {}

        # Training context (keeps tensors alive for backward pass)
        # session_id -> {input, output, prev_peer}
        self.training_context: Dict[str, Any] = {}

        logger.info(f"DynamicNeuroNode initialized: memory={self.available_memory_mb:.0f}MB")

    def _detect_available_memory(self) -> float:
        """Detect available system memory."""
        try:
            import psutil
            mem = psutil.virtual_memory()
            # Use 70% of available memory for safety
            return mem.available * 0.7 / (1024 * 1024)
        except ImportError:
            # Fallback
            return 2000  # Assume 2GB

    def start(self):
        """Start the node."""
        logger.info("Starting DynamicNeuroNode...")

        # 1. Initialize layer pool
        dht = None
        if self.p2p_manager and hasattr(self.p2p_manager, 'dht'):
            dht = self.p2p_manager.dht
        self.layer_pool = DynamicLayerPool(dht_protocol=dht)

        # Pass device hint for memory calculations (CPU needs more conservative limits)
        self.layer_pool._device_hint = self.device

        # 1b. SMART ARCHITECTURE RECONCILIATION
        # This handles the case where the network has evolved while we were offline
        self._reconcile_architecture()

        # 1c. PRE-FETCH TOKENIZER VOCAB SIZE for accurate memory calculation
        # This is critical for dynamic vocab - we need to know vocab size BEFORE
        # assigning layers, otherwise we'll assign too many and OOM when vocab expands
        self._prefetch_vocab_capacity()

        # 2. Get staked amount from ledger (for Validator eligibility)
        staked_amount = 0.0
        if self.p2p_manager and self.p2p_manager.ledger:
            try:
                account_info = self.p2p_manager.ledger.get_account_info()
                staked_amount = account_info.get("stake", 0.0)
                logger.info(f"Current stake: {staked_amount:.2f} NEURO")
            except Exception as e:
                logger.debug(f"Could not get stake info: {e}")

        # 3. Register with network and get layer assignments
        self.my_layer_ids = self.layer_pool.register_node(
            node_id=self.node_id,
            node_url=f"http://localhost:{self.port}",
            grpc_addr=f"localhost:{self.port + 1000}",
            available_memory_mb=self.available_memory_mb,
            staked_amount=staked_amount
        )

        logger.info(f"Assigned {len(self.my_layer_ids)} layers: {self.my_layer_ids}")

        # 4. Initialize model with my layers
        self.model = DynamicNeuroLLM(
            node_id=self.node_id,
            layer_pool=self.layer_pool,
            device=self.device
        )
        self.model.initialize_layers(self.my_layer_ids)

        # 4b. Set up callback for dynamic layer changes
        # When model removes layers (e.g., for vocab expansion), sync node state
        def on_layers_changed(new_layer_ids: List[int]):
            self.my_layer_ids = new_layer_ids
            # Update P2P shard_range if available
            if self.p2p_manager and new_layer_ids:
                new_start = min(new_layer_ids)
                new_end = max(new_layer_ids)
                self.p2p_manager.start_layer = new_start
                self.p2p_manager.end_layer = new_end
                self.p2p_manager.shard_range = f"{new_start}-{new_end}"
            logger.info(f"[NODE] Synced layer_ids after change: {new_layer_ids}")

        self.model._on_layers_changed = on_layers_changed

        # 5. Initialize tokenizer with learned BPE merges from CDN
        from neuroshard.core.model.tokenizer import get_neuro_tokenizer, NeuroTokenizer
        self.tokenizer = get_neuro_tokenizer()
        self._load_learned_tokenizer()  # Update with BPE merges from CDN

        # 6. Try to load existing checkpoint (resume training)
        self._load_checkpoint()

        # 7. Setup training
        if self.enable_training:
            self._setup_training()

        self.is_running = True

        # Log contribution
        contribution = self.model.get_my_contribution()
        logger.info(f"Node started: {contribution['my_params']/1e6:.1f}M params, "
                    f"{len(self.my_layer_ids)} layers, "
                    f"embed={self.model.has_embedding}, head={self.model.has_lm_head}")

        # Verify model is actually on the expected device
        if self.my_layer_ids and self.model.my_layers:
            first_layer = self.model.my_layers[self.my_layer_ids[0]]
            param_device = next(first_layer.parameters()).device
            if str(param_device) != self.device and not (self.device == "cuda" and "cuda" in str(param_device)):
                logger.error(f"[DEVICE] Model device mismatch! Expected {self.device}, got {param_device}")
            else:
                logger.info(f"[DEVICE] Model verified on: {param_device}")

    def _load_learned_tokenizer(self):
        """
        Load learned BPE tokenizer from Genesis CDN.

        This ensures the tokenizer used for inference matches the one used
        for training data tokenization, providing consistency across the network.
        """
        import requests
        import os

        GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
        cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")

        try:
            tokenizer_url = f"{GENESIS_CDN_URL}/tokenizer.json"
            tokenizer_cache_path = os.path.join(cache_dir, "tokenizer.json")

            # Try to fetch from CDN
            try:
                logger.debug(f"[TOKENIZER] Checking for learned tokenizer from {tokenizer_url}...")
                resp = requests.get(tokenizer_url, timeout=10)

                if resp.status_code == 200:
                    remote_tokenizer_data = resp.json()
                    remote_vocab_size = remote_tokenizer_data.get("next_merge_id", 0)

                    # Cache locally
                    os.makedirs(cache_dir, exist_ok=True)
                    with open(tokenizer_cache_path, 'w') as f:
                        f.write(resp.text)

                    # Update tokenizer if remote has more merges
                    if remote_vocab_size > self.tokenizer.next_merge_id:
                        from neuroshard.core.model.tokenizer import NeuroTokenizer
                        learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id

                        logger.info(f"[TOKENIZER] Loaded BPE tokenizer: {self.tokenizer.current_vocab_size} tokens, {len(self.tokenizer.merges)} merges")

                        # CRITICAL: Check if model needs vocabulary expansion after loading new tokenizer
                        if self.model is not None:
                            self.model.tokenizer = self.tokenizer
                            self.model.check_and_expand_vocab_if_needed()
                            # Update layer pool's vocab_capacity for future layer calculations
                            if hasattr(self, 'layer_pool') and self.layer_pool:
                                self.layer_pool.vocab_capacity = self.model.vocab_capacity
                    else:
                        logger.debug(f"[TOKENIZER] Already up to date: {self.tokenizer.current_vocab_size} tokens")
                    return
            except requests.RequestException as e:
                logger.debug(f"[TOKENIZER] CDN fetch failed: {e}")

            # Fallback to cached version
            if os.path.exists(tokenizer_cache_path) and self.tokenizer.next_merge_id <= 266:
                try:
                    from neuroshard.core.model.tokenizer import NeuroTokenizer
                    learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                    if learned_tokenizer.next_merge_id > self.tokenizer.next_merge_id:
                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id
                        logger.info(f"[TOKENIZER] Loaded cached BPE tokenizer: {self.tokenizer.current_vocab_size} tokens")

                        # CRITICAL: Check if model needs vocabulary expansion
                        if self.model is not None:
                            self.model.tokenizer = self.tokenizer
                            self.model.check_and_expand_vocab_if_needed()
                            # Update layer pool's vocab_capacity for future layer calculations
                            if hasattr(self, 'layer_pool') and self.layer_pool:
                                self.layer_pool.vocab_capacity = self.model.vocab_capacity
                except Exception as e:
                    logger.warning(f"[TOKENIZER] Failed to load cached tokenizer: {e}")

        except Exception as e:
            logger.warning(f"[TOKENIZER] Error loading learned tokenizer: {e}")

    def _setup_training(self):
        """Setup training components."""
        from neuroshard.core.training.distributed import FederatedDataManager

        # Collect all parameters from my layers
        all_params = []
        for layer in self.model.my_layers.values():
            all_params.extend(layer.parameters())
        if self.model.embedding:
            all_params.extend(self.model.embedding.parameters())
        if self.model.lm_head:
            all_params.extend(self.model.lm_head.parameters())
        if self.model.final_norm:
            all_params.extend(self.model.final_norm.parameters())

        self.optimizer = torch.optim.AdamW(all_params, lr=1e-4, weight_decay=0.01)

        self.data_manager = FederatedDataManager(
            tokenizer=self.tokenizer,
            max_seq_len=2048
        )

        # DYNAMIC TRAINING CONFIG: Calculate based on current model size and device
        # This will be recalculated when model grows via recalculate_training_config()
        num_layers = len(self.my_layer_ids)

        # Smart gradient checkpointing decision
        # CRITICAL: Always enable checkpointing for models with many layers!
        # Without checkpointing, attention scores alone need: batch × heads × seq² × layers × 4 bytes
        # For 46 layers: 8 × 16 × 2048² × 46 × 4 = ~92GB (way more than any GPU!)

        # SIMPLE RULE: Enable checkpointing if layers > 16 (always safe)
        # This avoids complex calculations that can have bugs with timing of vocab expansion
        if num_layers > 16:
            self._use_gradient_checkpointing = True
            logger.info(f"[NODE] Gradient checkpointing: ENABLED (layers={num_layers} > 16)")
        elif self.device != "cuda":
            # CPU/MPS always use checkpointing for memory efficiency
            self._use_gradient_checkpointing = True
            logger.info(f"[NODE] Gradient checkpointing: ENABLED (device={self.device})")
        else:
            # Small CUDA models can skip checkpointing for speed
            self._use_gradient_checkpointing = False
            logger.info(f"[NODE] Gradient checkpointing: DISABLED (layers={num_layers} ≤ 16, CUDA)")

        # Calculate memory-aware training batch size
        self._training_batch_size = self._calculate_training_batch_size()

        logger.info(f"Training initialized: batch_size={self._training_batch_size}, "
                    f"checkpointing={self._use_gradient_checkpointing}, "
                    f"layers={num_layers}, device={self.device}")

        # CUDA sanity check: verify GPU is actually usable
        if self.device == "cuda":
            try:
                import time as _time
                test_tensor = torch.randn(1000, 1000, device="cuda")
                start = _time.time()
                _ = torch.matmul(test_tensor, test_tensor)
                torch.cuda.synchronize()
                elapsed = _time.time() - start
                del test_tensor
                torch.cuda.empty_cache()
                logger.info(f"[CUDA] GPU sanity check passed: 1000x1000 matmul in {elapsed*1000:.1f}ms")
            except Exception as e:
                logger.error(f"[CUDA] GPU sanity check FAILED: {e}")
                logger.error("[CUDA] Training will likely run on CPU despite device=cuda!")

    def _calculate_training_batch_size(self) -> int:
        """
        Calculate optimal batch size based on available memory, device, and model size.

        DYNAMIC: This is called initially and can be recalculated when model grows.
        SMART: Considers GPU memory, gradient checkpointing, and actual model size.
        """
        seq_len = 512  # Typical sequence length
        hidden_dim = self.layer_pool.current_architecture.hidden_dim
        num_layers = len(self.my_layer_ids)

        # Calculate model memory footprint (params + gradients + optimizer states)
        model_params = sum(p.numel() for p in self.model.parameters())
        # Model memory: weights (fp32=4 bytes) × 4 (weights + grads + adam_m + adam_v)
        model_memory_mb = (model_params * 4 * 4) / (1024 * 1024)

        # For CUDA, check actual GPU memory available
        if self.device == "cuda":
            try:
                gpu_total = torch.cuda.get_device_properties(0).total_memory / (1024 * 1024)
                gpu_allocated = torch.cuda.memory_allocated(0) / (1024 * 1024)
                logger.info(f"[NODE] CUDA memory: {gpu_allocated:.0f}MB used / {gpu_total:.0f}MB total")
                effective_memory_mb = self.available_memory_mb
            except Exception:
                effective_memory_mb = self.available_memory_mb
        else:
            effective_memory_mb = self.available_memory_mb

        # CORRECT FORMULA: Available for activations = Total - Model memory
        # Leave 10% buffer for system overhead
        available_for_activations = max(100, (effective_memory_mb * 0.9) - model_memory_mb)

        # With gradient checkpointing, activation memory is MUCH lower
        use_checkpointing = getattr(self, '_use_gradient_checkpointing', False)
        if use_checkpointing:
            # Checkpointing: Only need to store ~sqrt(num_layers) worth of activations
            # Plus inputs/outputs at checkpoint boundaries
            checkpoint_segments = max(1, int(num_layers ** 0.5))
            # Memory per sample: seq_len × hidden_dim × checkpoint_segments × 4 bytes × 2 (fwd+bwd)
            mem_per_sample_mb = (seq_len * hidden_dim * checkpoint_segments * 4 * 2) / (1024 * 1024)
            logger.info(f"[NODE] Gradient checkpointing: {checkpoint_segments} segments "
                        f"(~{mem_per_sample_mb:.1f}MB/sample)")
        else:
            # No checkpointing: full activation memory for all layers
            mem_per_sample_mb = (seq_len * hidden_dim * num_layers * 4 * 2) / (1024 * 1024)

        logger.info(f"[NODE] Memory budget: total={effective_memory_mb:.0f}MB, "
                    f"model={model_memory_mb:.0f}MB, "
                    f"available_for_activations={available_for_activations:.0f}MB")

        # Calculate max batch size from available memory
        max_batch = max(1, int(available_for_activations / max(1, mem_per_sample_mb)))

        # SMART CLAMPING based on device capability
        if self.device == "cuda" and effective_memory_mb > 16000:
            # High-memory CUDA (Jetson Orin 32GB, RTX 3090 24GB): up to 8
            max_batch = min(max_batch, 8)
        elif self.device == "cuda" and effective_memory_mb > 8000:
            # Medium CUDA: up to 4
            max_batch = min(max_batch, 4)
        elif self.device == "cuda":
            # Small CUDA: up to 2
            max_batch = min(max_batch, 2)
        elif num_layers > 100:
            # Large model on CPU/MPS: conservative
            max_batch = min(max_batch, 2)
        else:
            max_batch = min(max_batch, 4)

        batch_size = max(1, max_batch)

        logger.info(f"[NODE] Training config: batch_size={batch_size}, "
                    f"model={model_params/1e6:.1f}M params ({num_layers} layers × {hidden_dim} dim), "
                    f"checkpointing={use_checkpointing}, device={self.device}")

        return batch_size
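
    # Worked example (illustrative numbers, not measured values): a CPU node with
    # effective_memory_mb = 8000 holding 24 layers of a hidden_dim = 1024 model
    # with ~350M local parameters and checkpointing enabled:
    #
    #     model_memory_mb           ≈ 350e6 * 4 * 4 / 1024**2 ≈ 5340
    #     available_for_activations = 8000 * 0.9 - 5340       ≈ 1860
    #     checkpoint_segments       = int(24 ** 0.5)          = 4
    #     mem_per_sample_mb         = 512 * 1024 * 4 * 4 * 2 / 1024**2 = 16
    #     max_batch                 = int(1860 / 16)          = 116 -> clamped to 4 (CPU path)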

    def recalculate_training_config(self):
        """
        Recalculate training configuration after model architecture changes.

        Called when:
        - Model grows (new layers added)
        - Memory allocation changes
        - Device changes
        """
        old_batch = getattr(self, '_training_batch_size', None)
        self._training_batch_size = self._calculate_training_batch_size()

        # Update gradient checkpointing based on new model size
        num_layers = len(self.my_layer_ids)
        old_checkpointing = getattr(self, '_use_gradient_checkpointing', False)

        # Simple checkpointing rule: enable if layers > 16
        if num_layers > 16:
            self._use_gradient_checkpointing = True
        elif self.device != "cuda":
            self._use_gradient_checkpointing = True
        else:
            self._use_gradient_checkpointing = False

        if old_batch != self._training_batch_size or old_checkpointing != self._use_gradient_checkpointing:
            logger.info(f"[NODE] Training config updated: batch_size={old_batch}→{self._training_batch_size}, "
                        f"checkpointing={old_checkpointing}→{self._use_gradient_checkpointing}")

    def stop(self):
        """Stop the node."""
        logger.info("Stopping DynamicNeuroNode...")

        self.is_running = False

        # Unregister from network
        if self.layer_pool:
            self.layer_pool.unregister_node(self.node_id)

        # Save checkpoint
        self._save_checkpoint()

        logger.info("DynamicNeuroNode stopped")

    def connect_p2p(self, p2p_manager):
        """Connect to P2P network."""
        self.p2p_manager = p2p_manager

        # Initialize Data Swarm
        from neuroshard.core.network.p2p_data import DataSwarm

        # Ensure cache dir exists in a writable location
        data_cache_dir = self.CHECKPOINT_DIR / "data_cache"
        data_cache_dir.mkdir(parents=True, exist_ok=True)

        self.swarm = DataSwarm(p2p_manager, cache_dir=str(data_cache_dir))

        # Update layer pool with DHT
        if self.layer_pool and hasattr(p2p_manager, 'dht'):
            self.layer_pool.dht = p2p_manager.dht

        # IMPORTANT: Give model access to p2p_manager for dynamic layer updates
        # When vocab expansion removes layers, the model needs to update DHT
        if self.model:
            self.model._p2p_manager = p2p_manager

        logger.info("Connected to P2P network and Data Swarm")

    # ==================== INFERENCE ====================

    def forward(self, input_ids: torch.Tensor, session_id: Optional[str] = None) -> torch.Tensor:
        """
        Forward pass - routes through network if needed.

        If this node has all layers: process locally.
        If not: forward to the nodes holding the other layers.
        """
        # Check if we can do full inference locally
        capacity = self.layer_pool.get_network_capacity()

        if len(self.my_layer_ids) == capacity.assigned_layers and self.model.has_embedding and self.model.has_lm_head:
            # We have everything - do local inference
            return self._forward_local(input_ids)
        else:
            # Need to route through network
            return self._forward_distributed(input_ids, session_id)

    def _forward_local(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Full local inference (when we have all layers)."""
        with torch.no_grad():
            # Embed
            hidden = self.model.embed(input_ids.to(self.device))

            # Forward through all layers
            hidden = self.model.forward_my_layers(hidden)

            # Compute logits
            logits = self.model.compute_logits(hidden)

            self.inference_count += 1
            self.total_tokens_processed += input_ids.numel()

            return logits

    def _forward_distributed(self, input_ids: torch.Tensor, session_id: Optional[str] = None) -> torch.Tensor:
        """Distributed inference through network pipeline."""
        # Get pipeline route
        route = self.layer_pool.get_pipeline_route()

        if not route:
            raise RuntimeError("No pipeline route available")

        # Start with embedding
        if self.model.has_embedding:
            hidden = self.model.embed(input_ids.to(self.device))
        else:
            # Request embedding from holder
            hidden = self._request_embedding(input_ids)

        # Forward through layers (local or remote)
        current_layer = 0
        for layer_id, grpc_addr in route:
            if layer_id in self.model.my_layers:
                # Local layer
                hidden, _ = self.model.my_layers[layer_id](hidden)
            else:
                # Remote layer - forward to peer
                hidden = self._forward_to_peer(grpc_addr, hidden, layer_id)
            current_layer = layer_id

        # Compute logits
        if self.model.has_lm_head:
            logits = self.model.compute_logits(hidden)
        else:
            # Request from holder
            logits = self._request_logits(hidden)

        self.inference_count += 1
        self.total_tokens_processed += input_ids.numel()

        return logits
|
|
2339
|
+
|
|
2340
|
+
    def forward_pipeline(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        position_ids: Optional[torch.Tensor] = None,
        training_labels: Optional[torch.Tensor] = None,
        session_id: Optional[str] = None,
        sender_url: Optional[str] = None,
        use_cache: bool = False
    ) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """
        Forward pass for pipeline parallelism (received from peer).
        """
        # Enable gradient tracking if training
        is_training = training_labels is not None

        if is_training:
            hidden_states.requires_grad_(True)
            hidden_states.retain_grad()

        # Check if input is token IDs (embedding request)
        # Integer dtype or 2D shape [batch, seq] implies input_ids
        # This happens when a client sends input_ids to the Driver (Layer 0)
        if (hidden_states.dtype in [torch.long, torch.int64, torch.int32] or
                len(hidden_states.shape) == 2) and self.model.has_embedding:

            # Ensure correct dtype
            if hidden_states.dtype != torch.long:
                hidden_states = hidden_states.to(torch.long)

            # Embed tokens
            hidden_states = self.model.embed(hidden_states)

            if is_training:
                hidden_states.requires_grad_(True)
                hidden_states.retain_grad()

        # Forward through local layers
        output = self.model.forward_my_layers(hidden_states)

        if is_training and session_id:
            # Save context for backward pass
            self.training_context[session_id] = {
                "input": hidden_states,
                "output": output,
                "sender_url": sender_url,
                "timestamp": time.time()
            }
            # Cleanup old sessions
            now = time.time()
            to_remove = [s for s, ctx in self.training_context.items() if now - ctx["timestamp"] > 600]
            for s in to_remove:
                del self.training_context[s]

        # If we are the Validator (Last Layer holder)
        # DYNAMIC CHECK: Query layer_pool for current lm_head_holder
        # This handles the case where a new Validator joined and took over
        is_current_validator = self.model.has_lm_head
        if hasattr(self, 'layer_pool') and self.layer_pool:
            is_current_validator = (self.layer_pool.lm_head_holder == self.node_id)

        if is_current_validator:
            logits = self.model.compute_logits(output)

            # Calculate Loss if labels present
            if training_labels is not None:
                loss = torch.nn.functional.cross_entropy(
                    logits.view(-1, logits.size(-1)),
                    training_labels.view(-1),
                    ignore_index=-100
                )

                # Use training lock to prevent conflict with local training
                with self._training_lock:
                    # Trigger Backward Pass
                    self.optimizer.zero_grad()
                    loss.backward()

                    # Propagate gradient back to previous node
                    if sender_url and session_id:
                        # The gradient we send back is dL/d(input_hidden_states)
                        # hidden_states.grad is populated by backward()
                        if hidden_states.grad is not None:
                            self._backward_to_peer(
                                sender_url,
                                hidden_states.grad,
                                # Target shard is whatever layer sent this to us.
                                # Assuming sender holds previous layers.
                                # We send to the sender's LAST layer.
                                # Simplified: just send to the node, it routes.
                                0,
                                session_id
                            )

                    # Step Optimizer
                    torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
                    self.optimizer.step()

                    self.total_training_rounds += 1
                    self.current_loss = loss.item()

                return logits, None

            return logits, None

        # If we are a Worker (Middle Layer), we need to forward to the next peer
        my_last_layer = max(self.my_layer_ids) if self.my_layer_ids else 0
        next_layer = my_last_layer + 1

        if self.p2p_manager:
            next_hop = self.p2p_manager.get_next_hop(next_layer)
            if next_hop:
                return self._forward_to_peer(
                    next_hop,
                    output,
                    next_layer,
                    labels=training_labels,
                    session_id=session_id
                )

        logger.warning(f"Pipeline broken at layer {next_layer}: no peer found")
        return output, None

    def backward_pipeline(self, grad_output: torch.Tensor, session_id: str):
        """
        Backward pass received from next peer.
        """
        if session_id not in self.training_context:
            logger.warning(f"Received backward for unknown session {session_id}")
            return

        ctx = self.training_context[session_id]
        output = ctx["output"]
        input_tensor = ctx["input"]
        sender_url = ctx["sender_url"]

        # Use training lock to prevent conflict with local training
        with self._training_lock:
            # Run local backward
            # output is the tensor we produced in forward_pipeline
            # grad_output is dL/d(output) received from the next peer
            self.optimizer.zero_grad()
            output.backward(grad_output)

            # Propagate back
            if sender_url and input_tensor.grad is not None:
                # The previous layer ID is not strictly needed for routing since we
                # have the direct sender URL, but _backward_to_peer takes a layer_id.
                self._backward_to_peer(sender_url, input_tensor.grad, 0, session_id)

            # Step Optimizer
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

        # Cleanup
        del self.training_context[session_id]

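    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # forward_pipeline() stores per-session activations in self.training_context so
    # that backward_pipeline() can later replay the local backward pass. A minimal,
    # self-contained model of that handshake on a toy layer (all names below are
    # hypothetical) looks like this:
    def _example_pipeline_backward_handshake():
        import torch

        layer = torch.nn.Linear(8, 8)
        training_context = {}

        # "forward_pipeline": run local layers, keep input/output for later
        x = torch.randn(2, 8, requires_grad=True)
        y = layer(x)
        training_context["sess-1"] = {"input": x, "output": y}

        # "backward_pipeline": the next peer sends dL/d(output); backprop locally
        grad_output = torch.randn_like(y)
        ctx = training_context.pop("sess-1")
        ctx["output"].backward(grad_output)

        # ctx["input"].grad is what would be forwarded to the previous peer
        return ctx["input"].grad.shape
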
    def _request_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
        """Request embedding from the node that holds it."""
        # Find a node that holds Layer 0 (Driver)
        peer_url = None

        # 1. Check layer pool assignments
        if self.layer_pool:
            assignments = self.layer_pool.get_layer_holders(0)
            if assignments:
                # Pick one (e.g., random for load balancing)
                import random
                peer_url = random.choice(assignments).grpc_addr

        # 2. Fallback to P2P manager routing
        if not peer_url and self.p2p_manager:
            peer_url = self.p2p_manager.get_next_hop(0)

        if not peer_url:
            raise RuntimeError("No embedding holder (Driver/Layer 0) found in network")

        # Call peer - send input_ids to the Layer 0 holder
        # The receiver's forward_pipeline will detect it's input_ids and run embed()
        result, _ = self._forward_to_peer(peer_url, input_ids, 0)
        return result

    def _forward_to_peer(self, peer_url: str, hidden: torch.Tensor, layer_id: int, labels: Optional[torch.Tensor] = None, session_id: Optional[str] = None) -> Tuple[torch.Tensor, Optional[List[torch.Tensor]]]:
        """
        Forward hidden states to a peer for processing.

        SECURITY: Calculates and validates SHA256 checksums to detect tampering.
        """
        from protos import neuroshard_pb2
        from protos import neuroshard_pb2_grpc
        from neuroshard.core.network.connection_pool import get_channel
        import numpy as np
        import hashlib

        try:
            parsed = urlparse(peer_url)
            ip = parsed.hostname
            # gRPC port convention
            port = (parsed.port or 80) + 1000

            channel = get_channel(f"{ip}:{port}")
            stub = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)

            # Serialize hidden states
            hidden_bytes = hidden.detach().cpu().numpy().tobytes()
            hidden_shape = list(hidden.shape)

            # CHECKSUM: Calculate SHA256 hash for integrity verification
            checksum = hashlib.sha256(hidden_bytes).hexdigest()
            logger.debug(f"[SECURITY] Sending layer {layer_id} with checksum: {checksum[:16]}...")

            # Serialize labels if present
            labels_bytes = b""
            if labels is not None:
                labels_bytes = labels.cpu().numpy().tobytes()

            req_session_id = session_id or f"train_{time.time()}"

            # Get my URL for backward routing
            my_url = ""
            if self.p2p_manager:
                my_url = self.p2p_manager.my_url

            req = neuroshard_pb2.PipelineForwardRequest(
                session_id=req_session_id,
                request_id=f"req_{time.time()}",
                hidden_states=hidden_bytes,
                hidden_shape=hidden_shape,
                target_shard=layer_id,
                use_cache=False,
                training_labels=labels_bytes,
                sender_url=my_url
            )

            # No extra bookkeeping is needed here for the backward pass:
            # the receiving peer will call PipelineBackward on us via sender_url.

            resp = stub.PipelineForward(req, timeout=30.0)

            if not resp.success:
                raise RuntimeError(f"Peer error: {resp.error_message}")

            # Deserialize result
            if resp.is_final:
                # It's logits
                result_bytes = resp.logits
                result = torch.from_numpy(
                    np.frombuffer(result_bytes, dtype=np.float32)
                ).reshape(list(resp.logits_shape))
            else:
                # It's hidden states (recursive/chained)
                result_bytes = resp.hidden_states
                result = torch.from_numpy(
                    np.frombuffer(result_bytes, dtype=np.float32)
                ).reshape(list(resp.hidden_shape))

            # CHECKSUM VALIDATION: Verify integrity of received data
            received_checksum = hashlib.sha256(result_bytes).hexdigest()
            logger.debug(f"[SECURITY] Received layer {layer_id} result with checksum: {received_checksum[:16]}...")

            # AUDIT TRAIL: Store checksum in PipelineSession for tamper detection
            if session_id and self.ledger and hasattr(self.ledger, 'inference_market'):
                market = self.ledger.inference_market
                if market and hasattr(market, 'active_sessions'):
                    for sess_id, session in market.active_sessions.items():
                        if sess_id == session_id or session.request_id in session_id:
                            session.activations_hashes.append(received_checksum)
                            logger.debug(f"[AUDIT] Stored checksum for layer {layer_id} in session")
                            break

            return result.to(self.device), None

        except Exception as e:
            logger.error(f"Failed to forward to peer {peer_url}: {e}")
            return hidden, None

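    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # The wire format above is raw float32 bytes plus an explicit shape, with a
    # SHA256 checksum logged on both ends. A minimal round-trip of that scheme
    # (hypothetical helper, no gRPC involved):
    def _example_tensor_checksum_roundtrip():
        import hashlib
        import numpy as np
        import torch

        hidden = torch.randn(1, 4, 16)

        # Sender side: serialize and fingerprint
        payload = hidden.detach().cpu().numpy().astype(np.float32).tobytes()
        shape = list(hidden.shape)
        sent_checksum = hashlib.sha256(payload).hexdigest()

        # Receiver side: restore the tensor and re-hash the exact bytes received
        restored = torch.from_numpy(
            np.frombuffer(payload, dtype=np.float32).copy()
        ).reshape(shape)
        received_checksum = hashlib.sha256(payload).hexdigest()

        assert sent_checksum == received_checksum
        assert torch.allclose(hidden, restored)
        return received_checksum[:16]
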
    def _backward_to_peer(self, peer_url: str, grad_output: torch.Tensor, layer_id: int, session_id: str):
        """Send gradients back to the previous peer."""
        from protos import neuroshard_pb2
        from protos import neuroshard_pb2_grpc
        from neuroshard.core.network.connection_pool import get_channel

        try:
            parsed = urlparse(peer_url)
            ip = parsed.hostname
            port = (parsed.port or 80) + 1000

            channel = get_channel(f"{ip}:{port}")
            stub = neuroshard_pb2_grpc.NeuroShardServiceStub(channel)

            grad_bytes = grad_output.detach().cpu().numpy().tobytes()
            grad_shape = list(grad_output.shape)

            req = neuroshard_pb2.PipelineBackwardRequest(
                session_id=session_id,
                request_id=f"bw_{time.time()}",
                grad_output=grad_bytes,
                grad_shape=grad_shape,
                target_shard=layer_id
            )

            stub.PipelineBackward(req, timeout=10.0)

        except Exception as e:
            logger.error(f"Failed to backward to peer {peer_url}: {e}")

    def _request_logits(self, hidden: torch.Tensor) -> torch.Tensor:
        """Request logits from the node that holds the LM head."""
        # Find Last Layer holder (Validator)
        if not self.layer_pool:
            return hidden

        capacity = self.layer_pool.get_network_capacity()
        last_layer = max(0, capacity.assigned_layers - 1)

        peer_url = None

        # 1. Check layer pool assignments
        assignments = self.layer_pool.get_layer_holders(last_layer)
        if assignments:
            import random
            peer_url = random.choice(assignments).grpc_addr

        # 2. Fallback to P2P manager
        if not peer_url and self.p2p_manager:
            peer_url = self.p2p_manager.get_next_hop(last_layer)

        if not peer_url:
            raise RuntimeError(f"No Validator (Layer {last_layer}) found in network")

        # Forward hidden states to the peer targeting the Last Layer
        # The receiver will compute logits and return is_final=True
        logits, _ = self._forward_to_peer(peer_url, hidden, last_layer)
        return logits

    def generate(
        self,
        prompt: str,
        max_new_tokens: int = 50,
        temperature: float = 1.0,
    ) -> str:
        """Generate text from prompt."""
        try:
            if not self.tokenizer:
                raise RuntimeError("Tokenizer not initialized")

            input_ids = torch.tensor([self.tokenizer.encode(prompt)], dtype=torch.long)
            logger.debug(f"[GENERATE] Encoded prompt: {input_ids.shape} tokens")

            # Move to model's device (handles CPU, CUDA, MPS)
            generated = input_ids.clone().to(self.device)

            # Get current vocabulary size from tokenizer
            # Only tokens 0 to current_vocab_size-1 are valid (have learned representations)
            # This is NOT a workaround - it's how BPE tokenizers work (vocab grows over time)
            valid_vocab_size = self.tokenizer.current_vocab_size

            for step in range(max_new_tokens):
                logits = self.forward(generated)
                next_logits = logits[:, -1, :] / temperature

                # Constrain to valid vocabulary (standard BPE tokenizer behavior)
                # Tokens beyond current_vocab_size don't exist in the tokenizer yet
                if valid_vocab_size < next_logits.size(-1):
                    next_logits[:, valid_vocab_size:] = float('-inf')

                probs = torch.softmax(next_logits, dim=-1)
                next_token = torch.multinomial(probs, num_samples=1)
                generated = torch.cat([generated, next_token], dim=-1)

                if next_token.item() == 2:  # EOS
                    logger.debug(f"[GENERATE] EOS at step {step+1}")
                    break

            prompt_tokens = input_ids.size(1)
            new_tokens = generated[0, prompt_tokens:].tolist()
            result = self.tokenizer.decode(new_tokens)
            logger.debug(f"[GENERATE] Generated {len(new_tokens)} tokens: '{result[:100]}...'")

            return result

        except Exception as e:
            logger.error(f"[GENERATE] Error: {e}")
            import traceback
            logger.error(traceback.format_exc())
            raise

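    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # The loop above scales logits by temperature and masks token ids at or beyond
    # the tokenizer's current vocab size before sampling. One isolated step of that
    # procedure on toy logits (values are made up):
    def _example_constrained_sampling_step():
        import torch

        logits = torch.randn(1, 10)      # pretend LM-head output for the last position
        temperature = 0.8
        valid_vocab_size = 6             # tokens 6..9 do not exist in the tokenizer yet

        next_logits = logits / temperature
        next_logits[:, valid_vocab_size:] = float('-inf')

        probs = torch.softmax(next_logits, dim=-1)
        next_token = torch.multinomial(probs, num_samples=1)

        assert next_token.item() < valid_vocab_size
        return next_token
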
    # ==================== TRAINING ====================

    def contribute_training_data(self, text: str, apply_dp: bool = True) -> int:
        """
        Contribute training data.

        Returns the number of tokens added.
        """
        if not self.data_manager:
            return 0

        # Get token count before
        stats_before = self.data_manager.get_stats()
        tokens_before = stats_before.get("total_tokens", 0)

        self.data_manager.add_text(text, apply_dp=apply_dp)

        # Get token count after
        stats_after = self.data_manager.get_stats()
        tokens_after = stats_after.get("total_tokens", 0)

        tokens_added = tokens_after - tokens_before
        logger.info(f"Added {tokens_added} tokens to training buffer")

        return tokens_added

    def _get_training_batch(self) -> Optional[Tuple[torch.Tensor, torch.Tensor]]:
        """
        Get a training batch from the Genesis data loader.

        Returns:
            Tuple of (input_ids, labels) or None if data not available.

        Note: Only Drivers (with embedding) load training data.
        Workers wait for activations via the pipeline; they don't need data directly.
        """
        if not self.enable_training:
            return None

        # WORKERS DON'T LOAD DATA - they receive activations via pipeline
        # Only Drivers (with embedding) need to load training data
        if not self.model.has_embedding:
            return None  # Worker - skip training data loading

        # Initialize genesis loader if needed
        if not hasattr(self, 'genesis_loader') or self.genesis_loader is None:
            try:
                from neuroshard.core.training.distributed import GenesisDataLoader
                from neuroshard.core.model.tokenizer import get_neuro_tokenizer
                logger.info("[GENESIS] Initializing data loader...")
                self.genesis_loader = GenesisDataLoader(
                    self.node_id,
                    get_neuro_tokenizer(),
                    max_storage_mb=self.max_storage_mb
                )
                logger.info(f"[GENESIS] Data loader ready: {self.genesis_loader.total_shards} shards available")

                # Connect Swarm to Loader
                if hasattr(self, 'swarm') and self.swarm:
                    self.genesis_loader.set_swarm(self.swarm)
            except Exception as e:
                logger.warning(f"[GENESIS] Failed to initialize loader: {e}")
                return None

        # Check if data is ready
        if not self.genesis_loader.is_data_ready():
            return None

        # Get batch
        batch_size = getattr(self, '_training_batch_size', 2)
        try:
            input_ids, labels = self.genesis_loader.get_batch(batch_size=batch_size)
            return input_ids, labels
        except Exception as e:
            logger.warning(f"[GENESIS] Failed to get batch: {e}")
            return None

    def train_step(self) -> Optional[float]:
        """
        Perform a training step on my layers.

        OPTIMIZED FOR SINGLE-NODE: When we have embedding + all layers + LM head,
        we skip distributed overhead and train locally.

        NON-BLOCKING DATA: Uses prefetched data when available, raises RuntimeError
        if data is not ready (caller should retry later).
        """
        if not self.enable_training:
            return None

        # RUNTIME MEMORY CHECK: Skip training if system memory is critically high
        # This prevents OOM crashes and keeps the system responsive
        try:
            import psutil
            mem = psutil.virtual_memory()
            # Skip if less than 15% of system RAM is free (critical threshold)
            if mem.percent > 85:
                logger.warning(f"[NODE] System memory at {mem.percent:.0f}%, skipping training step")
                # Also try to free some memory
                import gc
                gc.collect()
                if self.device == "cuda":
                    torch.cuda.empty_cache()
                elif self.device == "mps":
                    torch.mps.empty_cache()
                return None
        except Exception:
            pass  # If psutil fails, continue anyway

        try:
            # SINGLE-NODE OPTIMIZATION: Check if we're a full node (Driver + Worker + Validator)
            # DYNAMIC CHECK: Use layer_pool to get current lm_head_holder
            # This handles the case where a new Validator joined and took over the LM head
            am_current_validator = self.model.has_lm_head
            if hasattr(self, 'layer_pool') and self.layer_pool:
                am_current_validator = (self.layer_pool.lm_head_holder == self.node_id)

            is_full_node = self.model.has_embedding and am_current_validator

            if self.model.has_embedding:
                # I am a Driver (Layer 0)
                # Use Genesis Data Loader
                if not hasattr(self, 'genesis_loader') or self.genesis_loader is None:
                    try:
                        from neuroshard.core.training.distributed import GenesisDataLoader
                        from neuroshard.core.model.tokenizer import get_neuro_tokenizer
                        logger.info("[GENESIS] Initializing data loader...")
                        self.genesis_loader = GenesisDataLoader(
                            self.node_id,
                            get_neuro_tokenizer(),
                            max_storage_mb=self.max_storage_mb
                        )
                        logger.info(f"[GENESIS] Data loader ready: {self.genesis_loader.total_shards} shards available")

                        # Connect Swarm to Loader
                        if hasattr(self, 'swarm') and self.swarm:
                            self.genesis_loader.set_swarm(self.swarm)
                    except Exception as e:
                        import traceback
                        logger.error(f"[GENESIS] ERROR: {type(e).__name__}: {e}")
                        logger.error(f"[GENESIS] {traceback.format_exc()}")
                        # Mark as failed so we don't keep retrying immediately
                        self.genesis_loader = None
                        raise RuntimeError(f"Genesis loader init failed: {e}")

                # Check if data is ready (non-blocking)
                if not self.genesis_loader.is_data_ready():
                    # Data not ready - don't block, let caller retry
                    raise RuntimeError("Data not ready - shard still loading")

                # Get batch from Genesis Shard using memory-aware batch size
                batch_size = getattr(self, '_training_batch_size', 2)
                try:
                    input_ids, labels = self.genesis_loader.get_batch(batch_size=batch_size)
                    input_ids = input_ids.to(self.device)
                    labels = labels.to(self.device)
                except RuntimeError as e:
                    # Data not ready - propagate up
                    raise
                except Exception as e:
                    logger.warning(f"[GENESIS] Failed to get batch: {type(e).__name__}: {e}")
                    import traceback
                    logger.warning(traceback.format_exc())
                    return None

                # SINGLE-NODE OPTIMIZED PATH: Skip distributed overhead
                if is_full_node:
                    return self._train_step_local(input_ids, labels)

                # DISTRIBUTED PATH: Forward to next peer
                # Forward pass with optional gradient checkpointing
                # Note: time.sleep(0) yields the GIL to keep the HTTP server responsive
                embeddings = self.model.embed(input_ids)
                embeddings.requires_grad_(True)
                embeddings.retain_grad()
                time.sleep(0)  # Yield GIL

                # Use gradient checkpointing if enabled (trades compute for memory)
                if getattr(self, '_use_gradient_checkpointing', False):
                    output = torch.utils.checkpoint.checkpoint(
                        self.model.forward_my_layers,
                        embeddings,
                        use_reentrant=False
                    )
                else:
                    output = self.model.forward_my_layers(embeddings)
                time.sleep(0)  # Yield GIL after forward pass

                # Distributed: Send to next peer
                my_last_layer = max(self.my_layer_ids) if self.my_layer_ids else 0
                next_layer = my_last_layer + 1

                if self.p2p_manager:
                    next_hop = self.p2p_manager.get_next_hop(next_layer)
                    if next_hop:
                        session_id = f"train_{self.node_id}_{time.time()}"

                        # Save context for backward
                        self.training_context[session_id] = {
                            "input": embeddings,
                            "output": output,
                            "sender_url": None,  # We are the start
                            "timestamp": time.time()
                        }

                        result, _ = self._forward_to_peer(
                            next_hop,
                            output,
                            next_layer,
                            labels=labels,
                            session_id=session_id
                        )

                        # Check if forward succeeded (result should differ from output if it was processed)
                        # If the peer rejected or failed, result will be the original output (unchanged)
                        forward_succeeded = result is not output

                        if not forward_succeeded:
                            # Pipeline forward failed - peer rejected or error
                            # Clean up the training context
                            if session_id in self.training_context:
                                del self.training_context[session_id]
                            logger.warning(f"[DISTRIBUTED] Pipeline forward failed - skipping training step")
                            return None

                        # We don't get the loss immediately in a distributed pipeline
                        # It comes back later via the backward pass or a status update
                        # For now, return None (not inf!)
                        return None

                return None

            else:
                # I am a Worker/Validator
                # I wait for activations from peers via gRPC (forward_pipeline)
                # So this method does nothing actively
                return None

        except RuntimeError as e:
            error_msg = str(e)
            if "not ready" in error_msg.lower():
                # Data not ready - propagate to caller
                raise
            elif "out of memory" in error_msg.lower() or "MPS" in error_msg:
                logger.warning(f"Training step OOM - reducing batch size and clearing cache")
                # Clear GPU cache
                import gc
                gc.collect()
                if self.device == "mps":
                    torch.mps.empty_cache()
                elif self.device == "cuda":
                    torch.cuda.empty_cache()

                # Reduce batch size for next attempt
                current_batch = getattr(self, '_training_batch_size', 8)
                if current_batch > 1:
                    self._training_batch_size = max(1, current_batch // 2)
                    logger.info(f"Reduced batch size to {self._training_batch_size}")
                else:
                    # Already at minimum batch size - nothing more to reduce
                    if self.device != "cpu":
                        logger.warning(f"Batch size already at minimum. Consider using --memory flag to limit layers.")
            else:
                logger.error(f"Training step failed: {e}")
            return None
        except Exception as e:
            logger.error(f"Training step failed: {e}")
            return None

    def _train_step_local(self, input_ids: torch.Tensor, labels: torch.Tensor) -> float:
        """
        OPTIMIZED single-node training step.

        When we have ALL components (embedding + layers + LM head), we can
        train entirely locally without any network overhead.
        """
        # DIAGNOSTIC: Verify device placement periodically
        if self.total_training_rounds % 100 == 0:
            try:
                emb_device = next(self.model.embedding.parameters()).device if self.model.embedding else 'N/A'
                layer_device = next(next(iter(self.model.my_layers.values())).parameters()).device if self.model.my_layers else 'N/A'
                logger.info(f"[TRAIN] Device check: input={input_ids.device}, embedding={emb_device}, layer0={layer_device}")
            except Exception as e:
                logger.warning(f"[TRAIN] Device check failed: {e}")

        # Forward pass with optional gradient checkpointing
        embeddings = self.model.embed(input_ids)

        # Use gradient checkpointing if enabled (trades compute for memory)
        if getattr(self, '_use_gradient_checkpointing', False):
            output = torch.utils.checkpoint.checkpoint(
                self.model.forward_my_layers,
                embeddings,
                use_reentrant=False
            )
        else:
            output = self.model.forward_my_layers(embeddings)

        # Compute logits and loss
        logits = self.model.compute_logits(output)

        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)),
            labels.view(-1),
            ignore_index=-100
        )

        # Use training lock to prevent conflict with pipeline training
        with self._training_lock:
            # Backward pass
            self.optimizer.zero_grad()
            loss.backward()

            # Gradient clipping and optimizer step
            torch.nn.utils.clip_grad_norm_(self.model.parameters(), 1.0)
            self.optimizer.step()

            # Update stats
            self.total_training_rounds += 1
            loss_val = loss.item()
            self.current_loss = loss_val

            # PERIODIC CHECKPOINT: Save every 100 steps
            # Synchronous save blocks training briefly (~30-60s) but avoids memory pressure
            if self.total_training_rounds % 100 == 0:
                self._save_checkpoint()

        return loss_val

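    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # Gradient checkpointing, as used above, skips storing intermediate activations
    # during the forward pass and recomputes them during backward, trading extra
    # compute for lower peak memory. A minimal stand-alone comparison on a toy stack
    # of layers (all names hypothetical):
    def _example_gradient_checkpointing():
        import torch
        import torch.utils.checkpoint

        blocks = torch.nn.Sequential(*[torch.nn.Linear(64, 64) for _ in range(4)])

        # Plain forward: activations of every block are kept for backward
        x1 = torch.randn(8, 64, requires_grad=True)
        blocks(x1).sum().backward()

        # Checkpointed forward: activations are recomputed during backward
        x2 = torch.randn(8, 64, requires_grad=True)
        out = torch.utils.checkpoint.checkpoint(blocks, x2, use_reentrant=False)
        out.sum().backward()

        return x2.grad.shape
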
# ==================== STATS & PONW ====================
|
|
3060
|
+
|
|
3061
|
+
def get_stats(self) -> Dict[str, Any]:
|
|
3062
|
+
"""Get node statistics."""
|
|
3063
|
+
# Safety check for shutdown race condition
|
|
3064
|
+
model = getattr(self, 'model', None)
|
|
3065
|
+
layer_pool = getattr(self, 'layer_pool', None)
|
|
3066
|
+
|
|
3067
|
+
contribution = model.get_my_contribution() if model else {}
|
|
3068
|
+
capacity = layer_pool.get_network_capacity() if layer_pool else None
|
|
3069
|
+
|
|
3070
|
+
# Calculate reward multiplier
|
|
3071
|
+
my_layer_ids = getattr(self, 'my_layer_ids', [])
|
|
3072
|
+
network_layers = capacity.assigned_layers if capacity else len(my_layer_ids)
|
|
3073
|
+
reward_multiplier = calculate_reward_multiplier(
|
|
3074
|
+
num_layers_held=len(my_layer_ids),
|
|
3075
|
+
total_network_layers=network_layers or 1,
|
|
3076
|
+
has_embedding=model.has_embedding if model else False,
|
|
3077
|
+
has_lm_head=model.has_lm_head if model else False,
|
|
3078
|
+
)
|
|
3079
|
+
|
|
3080
|
+
# Estimate network params (rough: ~10M params per layer)
|
|
3081
|
+
network_params = network_layers * 10_000_000 if network_layers else 0
|
|
3082
|
+
|
|
3083
|
+
# Get data buffer size
|
|
3084
|
+
data_buffer_size = 0
|
|
3085
|
+
if self.data_manager:
|
|
3086
|
+
data_stats = self.data_manager.get_stats()
|
|
3087
|
+
data_buffer_size = data_stats.get("buffer_size", 0)
|
|
3088
|
+
|
|
3089
|
+
# Get shard stats (if we have a genesis loader)
|
|
3090
|
+
shard_stats = {}
|
|
3091
|
+
if hasattr(self, 'genesis_loader') and self.genesis_loader:
|
|
3092
|
+
shard_stats = self.genesis_loader.get_stats()
|
|
3093
|
+
|
|
3094
|
+
# Multi-node identity info
|
|
3095
|
+
instance_id = getattr(self, 'instance_id', None)
|
|
3096
|
+
wallet_id = getattr(self, 'wallet_id', None)
|
|
3097
|
+
|
|
3098
|
+
return {
|
|
3099
|
+
"node_id": self.node_id[:16] + "...",
|
|
3100
|
+
"instance_id": instance_id, # Unique per machine+port
|
|
3101
|
+
"wallet_id": wallet_id, # Same across instances with same token
|
|
3102
|
+
"available_memory_mb": self.available_memory_mb,
|
|
3103
|
+
"my_layers": self.my_layer_ids,
|
|
3104
|
+
"my_params": contribution.get("my_params", 0),
|
|
3105
|
+
"has_embedding": contribution.get("has_embedding", False),
|
|
3106
|
+
"has_lm_head": contribution.get("has_lm_head", False),
|
|
3107
|
+
"contribution_ratio": contribution.get("contribution_ratio", 0),
|
|
3108
|
+
"reward_multiplier": reward_multiplier,
|
|
3109
|
+
"network_layers": network_layers,
|
|
3110
|
+
"network_params": network_params,
|
|
3111
|
+
"network_nodes": capacity.total_nodes if capacity else 1,
|
|
3112
|
+
"total_tokens_processed": self.total_tokens_processed,
|
|
3113
|
+
"total_training_rounds": self.total_training_rounds,
|
|
3114
|
+
"current_loss": self.current_loss,
|
|
3115
|
+
"inference_count": self.inference_count,
|
|
3116
|
+
"data_buffer_size": data_buffer_size,
|
|
3117
|
+
"shard_stats": shard_stats,
|
|
3118
|
+
}
|
|
3119
|
+
|
|
3120
|
+
    def get_ponw_proof(self) -> Dict[str, Any]:
        """
        Generate Proof of Neural Work.

        This proof demonstrates verifiable neural network computation
        and is used for NEURO token rewards.
        """
        contribution = self.model.get_my_contribution() if self.model else {}
        capacity = self.layer_pool.get_network_capacity() if self.layer_pool else None

        # Calculate reward multiplier
        multiplier = calculate_reward_multiplier(
            num_layers_held=len(self.my_layer_ids),
            total_network_layers=capacity.assigned_layers if capacity else 1,
            has_embedding=self.model.has_embedding if self.model else False,
            has_lm_head=self.model.has_lm_head if self.model else False,
        )

        timestamp = time.time()

        # Determine role
        role = "Worker"
        if self.model and self.model.has_embedding:
            role = "Driver"
        elif self.model and self.model.has_lm_head:
            role = "Validator"

        proof_data = {
            "node_id": self.node_id,
            "timestamp": timestamp,
            "tokens_processed": self.total_tokens_processed,
            "training_rounds": self.total_training_rounds,
            "training_contributions": self.training_contribution_count,
            "inference_count": self.inference_count,
            "layers_held": len(self.my_layer_ids),
            "layer_ids": self.my_layer_ids,
            "has_embedding": self.model.has_embedding if self.model else False,
            "has_lm_head": self.model.has_lm_head if self.model else False,
            "role": role,
            "reward_multiplier": multiplier,
            "available_memory_mb": self.available_memory_mb,
        }

        # Add model hash for verification
        # Use architecture-based hash (consistent with SwarmEnabledDynamicNode._get_model_hash)
        if self.model:
            hasher = hashlib.sha256()
            arch_str = f"{self.model.hidden_dim}:{len(self.my_layer_ids)}:{getattr(self.model, 'num_heads', 0)}"
            hasher.update(arch_str.encode())
            for name, param in sorted(self.model.named_parameters()):
                hasher.update(f"{name}:{list(param.shape)}".encode())
            proof_data["model_hash"] = hasher.hexdigest()[:16]

        # Sign the proof
        proof_string = f"{self.node_id}:{timestamp}:{self.total_tokens_processed}:{len(self.my_layer_ids)}:{self.total_training_rounds}"
        if self.node_token:
            # Use HMAC for proper signing
            import hmac
            signature = hmac.new(
                self.node_token.encode(),
                proof_string.encode(),
                hashlib.sha256
            ).hexdigest()
        else:
            signature = hashlib.sha256(proof_string.encode()).hexdigest()

        proof_data["signature"] = signature

        return proof_data

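    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # A verifier that knows the node token can recompute the HMAC over the same field
    # order used when the proof was signed. Hypothetical helper; the real verification
    # path lives elsewhere in the package:
    def _example_verify_ponw_signature(proof: dict, node_token: str) -> bool:
        import hashlib
        import hmac

        proof_string = (
            f"{proof['node_id']}:{proof['timestamp']}:{proof['tokens_processed']}:"
            f"{proof['layers_held']}:{proof['training_rounds']}"
        )
        expected = hmac.new(
            node_token.encode(), proof_string.encode(), hashlib.sha256
        ).hexdigest()
        # Constant-time comparison avoids leaking how many leading characters match
        return hmac.compare_digest(expected, proof.get("signature", ""))
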
    def _prefetch_vocab_capacity(self):
        """
        Pre-fetch tokenizer vocab size to know how much memory embeddings will need.

        This MUST be called before register_node() to ensure accurate layer assignment.
        Without this, we'd assign layers assuming 32K vocab, then OOM when vocab expands to 288K+.
        """
        import requests
        import os

        GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
        cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")
        tokenizer_cache_path = os.path.join(cache_dir, "tokenizer.json")

        vocab_size = INITIAL_VOCAB_SIZE  # Default fallback

        try:
            # Try to fetch vocab size from CDN
            tokenizer_url = f"{GENESIS_CDN_URL}/tokenizer.json"
            resp = requests.get(tokenizer_url, timeout=10)

            if resp.status_code == 200:
                data = resp.json()
                vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)

                # Cache locally for faster startup next time
                os.makedirs(cache_dir, exist_ok=True)
                with open(tokenizer_cache_path, 'w') as f:
                    f.write(resp.text)

                logger.info(f"[VOCAB] Pre-fetched tokenizer: {vocab_size:,} tokens (for memory calculation)")
            else:
                # Try cached version
                if os.path.exists(tokenizer_cache_path):
                    import json
                    with open(tokenizer_cache_path, 'r') as f:
                        data = json.load(f)
                    vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)
                    logger.info(f"[VOCAB] Using cached tokenizer: {vocab_size:,} tokens")
        except Exception as e:
            # Try cached version as fallback
            try:
                if os.path.exists(tokenizer_cache_path):
                    import json
                    with open(tokenizer_cache_path, 'r') as f:
                        data = json.load(f)
                    vocab_size = data.get("next_merge_id", INITIAL_VOCAB_SIZE)
                    logger.info(f"[VOCAB] Using cached tokenizer: {vocab_size:,} tokens (CDN unavailable)")
            except Exception:
                pass
            logger.debug(f"[VOCAB] Could not prefetch vocab size: {e}, using default {INITIAL_VOCAB_SIZE}")

        # Round up to next chunk boundary (no headroom - recalculate if vocab grows)
        # Previously used 64K headroom but this wastes ~1GB memory on limited devices
        vocab_capacity = ((vocab_size + VOCAB_GROWTH_CHUNK - 1) // VOCAB_GROWTH_CHUNK) * VOCAB_GROWTH_CHUNK

        # Update layer pool's vocab_capacity for accurate layer assignment
        self.layer_pool.vocab_capacity = vocab_capacity
        logger.info(f"[VOCAB] Layer pool vocab_capacity set to {vocab_capacity:,} (current vocab: {vocab_size:,})")

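    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # The capacity computed above is simply the vocab size rounded UP to the next
    # VOCAB_GROWTH_CHUNK boundary. With made-up numbers (a 290,001-token vocab and a
    # hypothetical 32,768-token growth chunk), the integer arithmetic works out like this:
    def _example_vocab_capacity_rounding():
        vocab_size = 290_001
        growth_chunk = 32_768  # hypothetical value; the real constant is VOCAB_GROWTH_CHUNK

        capacity = ((vocab_size + growth_chunk - 1) // growth_chunk) * growth_chunk
        assert capacity == 294_912               # 9 chunks of 32,768
        assert capacity - vocab_size < growth_chunk
        return capacity
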
    def _reconcile_architecture(self):
        """
        Smart architecture reconciliation for rejoining the network.

        Handles all scenarios:
        1. Quick restart (same architecture) → Use checkpoint
        2. Network upgraded (larger arch) → Start fresh with network arch
        3. Network downgraded (smaller arch) → Start fresh with network arch
        4. Solo bootstrap (no peers) → Use checkpoint or calculate
        5. First time (no checkpoint) → Query network or calculate

        Priority:
        1. Network consensus (if peers available)
        2. Saved checkpoint (if compatible)
        3. Fresh calculation (fallback)
        """
        saved_arch = self._peek_checkpoint_architecture()
        network_arch = self._query_network_architecture()

        # Log what we found
        if saved_arch:
            logger.info(f"Saved checkpoint: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H")
        else:
            logger.info("No saved checkpoint found")

        if network_arch:
            logger.info(f"Network architecture: {network_arch.num_layers}L × {network_arch.hidden_dim}H")
        else:
            logger.info("No peers found (solo mode or bootstrap)")

        # Decision matrix
        if network_arch and saved_arch:
            # Both exist - compare them
            if self._architectures_compatible(saved_arch, network_arch):
                # Perfect - checkpoint matches network
                logger.info("✅ Checkpoint compatible with network - will load checkpoint")
                self.layer_pool.current_architecture = network_arch
                self.layer_pool.current_num_layers = network_arch.num_layers
            else:
                # Mismatch - network takes priority
                logger.warning("⚠️ Architecture mismatch!")
                logger.warning(f"   Checkpoint: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H")
                logger.warning(f"   Network: {network_arch.num_layers}L × {network_arch.hidden_dim}H")

                # Check if network arch fits in our memory
                network_memory = network_arch.estimate_memory_mb()
                if network_memory <= self.available_memory_mb:
                    logger.warning("   → Using NETWORK architecture (checkpoint will be incompatible)")
                    logger.warning("   → Your training progress will be preserved but weights reset")
                    self.layer_pool.current_architecture = network_arch
                    self.layer_pool.current_num_layers = network_arch.num_layers
                    # Rename old checkpoint instead of deleting
                    self._archive_incompatible_checkpoint()
                else:
                    logger.error(f"   → Network arch needs {network_memory}MB but you only have {self.available_memory_mb}MB!")
                    logger.error("   → This node cannot participate in current network")
                    logger.error("   → Falling back to solo mode with checkpoint architecture")
                    self.layer_pool.current_architecture = saved_arch
                    self.layer_pool.current_num_layers = saved_arch.num_layers

        elif network_arch:
            # Network exists but no checkpoint - join the network
            network_memory = network_arch.estimate_memory_mb()
            if network_memory <= self.available_memory_mb:
                logger.info(f"✅ Joining network with architecture: {network_arch.num_layers}L × {network_arch.hidden_dim}H")
                self.layer_pool.current_architecture = network_arch
                self.layer_pool.current_num_layers = network_arch.num_layers
            else:
                logger.warning(f"⚠️ Network arch needs {network_memory}MB but you only have {self.available_memory_mb}MB")
                logger.warning("   → Will calculate a smaller architecture (may train in isolation)")
                # Let register_node calculate appropriate architecture

        elif saved_arch:
            # Checkpoint exists but no network peers (solo mode)
            # IMPORTANT: Check ACTUAL layers in checkpoint, not just architecture's num_layers
            actual_saved_layers = self._get_checkpoint_layer_count()
            if actual_saved_layers and actual_saved_layers > saved_arch.num_layers:
                # Model grew beyond base architecture - calculate memory for actual layers
                memory_per_layer = estimate_memory_per_layer(saved_arch)
                saved_memory = memory_per_layer * actual_saved_layers * 1.1  # 10% overhead
                logger.info(f"Checkpoint has {actual_saved_layers} layers (grew from base {saved_arch.num_layers})")
            else:
                saved_memory = saved_arch.estimate_memory_mb()
                actual_saved_layers = saved_arch.num_layers

            if saved_memory <= self.available_memory_mb:
                logger.info(f"✅ Solo mode - using saved checkpoint: {actual_saved_layers}L × {saved_arch.hidden_dim}H")
                self.layer_pool.current_architecture = saved_arch
                self.layer_pool.current_num_layers = actual_saved_layers
            else:
                # Memory too small for saved checkpoint - calculate what we CAN fit
                # Use current vocab_capacity for accurate memory estimation
                vocab_cap = getattr(self.layer_pool, 'vocab_capacity', INITIAL_VOCAB_SIZE)
                max_layers = calculate_layer_assignment(
                    self.available_memory_mb, saved_arch,
                    safety_factor=0.6, vocab_capacity=vocab_cap,
                    training_mode=True  # Conservative for training
                )
                logger.warning(f"⚠️ Saved checkpoint has {actual_saved_layers} layers ({saved_memory:.0f}MB) "
                               f"but you only have {self.available_memory_mb:.0f}MB")
                logger.warning(f"   → Will use {max_layers} layers (reduced from checkpoint)")
                self.layer_pool.current_architecture = saved_arch
                self.layer_pool.current_num_layers = max_layers

        else:
            # No checkpoint, no network - fresh start
            logger.info("Fresh start - architecture will be calculated from available memory")

    def _query_network_architecture(self) -> Optional[ModelArchitecture]:
        """
        Query the network for the current architecture.

        Tries multiple sources, in order:
        1. Tracker API for network stats
        2. Direct peer query
        3. DHT lookup for architecture announcements

        Returns None if no peers available (solo mode).
        """
        import requests

        # Method 1: Try tracker API first (fastest, most reliable)
        if self.tracker_url:
            try:
                # Query tracker for network architecture
                response = requests.get(
                    f"{self.tracker_url}/network_architecture",
                    timeout=5
                )
                if response.ok:
                    data = response.json()
                    if data.get("hidden_dim"):
                        arch = ModelArchitecture(
                            hidden_dim=data["hidden_dim"],
                            intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
                            num_layers=data.get("num_layers", 12),
                            num_heads=data.get("num_heads", 12),
                            num_kv_heads=data.get("num_kv_heads", 4),
                        )
                        logger.debug(f"Got network architecture from tracker: {arch.num_layers}L × {arch.hidden_dim}H")
                        return arch
            except Exception as e:
                logger.debug(f"Tracker architecture query failed: {e}")

        # Method 2: Query known peers directly
        if self.p2p_manager and self.p2p_manager.known_peers:
            for peer_url in list(self.p2p_manager.known_peers.keys())[:3]:
                try:
                    response = requests.get(
                        f"{peer_url}/api/node/architecture",
                        timeout=3
                    )
                    if response.ok:
                        data = response.json()
                        if data.get("hidden_dim"):
                            arch = ModelArchitecture(
                                hidden_dim=data["hidden_dim"],
                                intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
                                num_layers=data.get("num_layers", 12),
                                num_heads=data.get("num_heads", 12),
                                num_kv_heads=data.get("num_kv_heads", 4),
                            )
                            logger.debug(f"Got network architecture from peer {peer_url}: {arch.num_layers}L × {arch.hidden_dim}H")
                            return arch
                except Exception:
                    continue

        # Method 3: DHT lookup (if available)
        if self.p2p_manager and hasattr(self.p2p_manager, 'dht') and self.p2p_manager.dht:
            try:
                import hashlib
                key = int(hashlib.sha1("network_architecture".encode()).hexdigest(), 16)
                value = self.p2p_manager.dht.lookup_value(key)
                if value:
                    import json
                    data = json.loads(value)
                    if isinstance(data, dict) and data.get("hidden_dim"):
                        arch = ModelArchitecture(
                            hidden_dim=data["hidden_dim"],
                            intermediate_dim=data.get("intermediate_dim", int(data["hidden_dim"] * 8 / 3)),
                            num_layers=data.get("num_layers", 12),
                            num_heads=data.get("num_heads", 12),
                            num_kv_heads=data.get("num_kv_heads", 4),
                        )
                        logger.debug(f"Got network architecture from DHT: {arch.num_layers}L × {arch.hidden_dim}H")
                        return arch
            except Exception as e:
                logger.debug(f"DHT architecture lookup failed: {e}")

        return None

    def _architectures_compatible(self, arch1: ModelArchitecture, arch2: ModelArchitecture) -> bool:
        """
        Check if two architectures are compatible for gradient exchange.

        Compatible means: same hidden_dim, num_heads, num_kv_heads
        (num_layers can differ - nodes just hold different subsets)
        """
        return (
            arch1.hidden_dim == arch2.hidden_dim and
            arch1.num_heads == arch2.num_heads and
            arch1.num_kv_heads == arch2.num_kv_heads
        )

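    # --- Illustrative sketch (editor addition, not part of the original source) ---
    # Compatibility above ignores num_layers on purpose: two nodes can hold different
    # layer subsets as long as the per-layer tensor shapes line up. A toy check using
    # a stand-in config type (hypothetical; the real class is ModelArchitecture):
    def _example_architecture_compatibility():
        from collections import namedtuple

        Arch = namedtuple("Arch", "hidden_dim num_heads num_kv_heads num_layers")
        a = Arch(hidden_dim=768, num_heads=12, num_kv_heads=4, num_layers=12)
        b = Arch(hidden_dim=768, num_heads=12, num_kv_heads=4, num_layers=48)
        c = Arch(hidden_dim=1024, num_heads=16, num_kv_heads=4, num_layers=12)

        def compatible(x, y):
            return (x.hidden_dim == y.hidden_dim and
                    x.num_heads == y.num_heads and
                    x.num_kv_heads == y.num_kv_heads)

        assert compatible(a, b)       # depth differs, shapes match
        assert not compatible(a, c)   # hidden size differs
        return True
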
    def _archive_incompatible_checkpoint(self):
        """
        Archive an incompatible checkpoint instead of deleting it.

        Storage-aware: Keeps only MAX_ARCHIVED_CHECKPOINTS and respects
        the user's storage budget.
        """
        MAX_ARCHIVED_CHECKPOINTS = 2  # Keep at most 2 old checkpoints

        path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"

        if not path.exists():
            return

        # First, clean up old archives to stay within limits
        self._cleanup_old_archives(MAX_ARCHIVED_CHECKPOINTS - 1)  # Make room for new one

        # Now archive the current checkpoint
        import time
        timestamp = int(time.time())
        archive_path = self.CHECKPOINT_DIR / f"archived_{self.wallet_id}_{timestamp}.pt"

        try:
            path.rename(archive_path)
            logger.info(f"Archived incompatible checkpoint to: {archive_path.name}")
        except Exception as e:
            logger.warning(f"Could not archive checkpoint: {e}")
            # If archive fails, just delete it
            try:
                path.unlink()
                logger.info("Deleted incompatible checkpoint (archive failed)")
            except Exception:
                pass

    def _cleanup_old_archives(self, max_keep: int = 2):
        """
        Clean up old archived checkpoints, keeping only the most recent ones.

        Also enforces the storage budget if archives are taking too much space.
        """
        # Find all archives for this wallet
        pattern = f"archived_{self.wallet_id}_*.pt"
        archives = sorted(
            self.CHECKPOINT_DIR.glob(pattern),
            key=lambda p: p.stat().st_mtime,
            reverse=True  # Newest first
        )

        # Calculate total archive size
        total_archive_mb = sum(p.stat().st_size / (1024 * 1024) for p in archives)

        # Storage budget: archives should use at most 20% of max_storage
        archive_budget_mb = self.max_storage_mb * 0.2

        # Delete archives that exceed count OR storage limits
        deleted_count = 0
        for i, archive in enumerate(archives):
            should_delete = False

            # Too many archives
            if i >= max_keep:
                should_delete = True

            # Over storage budget
            if total_archive_mb > archive_budget_mb:
                should_delete = True

            if should_delete:
                try:
                    archive_size_mb = archive.stat().st_size / (1024 * 1024)
                    archive.unlink()
                    total_archive_mb -= archive_size_mb
                    deleted_count += 1
                    logger.debug(f"Cleaned up old archive: {archive.name}")
                except Exception:
                    pass

        if deleted_count > 0:
            logger.info(f"Cleaned up {deleted_count} old archived checkpoint(s)")

    def _get_checkpoint_layer_count(self) -> Optional[int]:
        """
        Get the actual number of layers saved in checkpoint.

        This is important because the model may have GROWN beyond the base architecture.
        The architecture might say 11 layers, but 110 layers could be saved!
        """
        path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
        legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"

        checkpoint_path = path if path.exists() else (legacy_path if legacy_path.exists() else None)
        if not checkpoint_path:
            return None

        try:
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)
            layers = checkpoint.get("layers", {})
            if layers:
                return len(layers)
            # Fall back to layer_ids if present
            layer_ids = checkpoint.get("layer_ids", [])
            if layer_ids:
                return len(layer_ids)
        except Exception as e:
            logger.debug(f"Could not get checkpoint layer count: {e}")

        return None

    def _peek_checkpoint_architecture(self) -> Optional[ModelArchitecture]:
        """
        Peek at checkpoint to get saved architecture WITHOUT loading weights.

        This allows us to use the same architecture as the checkpoint,
        preventing architecture drift between restarts on the same machine.
        """
        # Use wallet_id for stable checkpoint path
        path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"

        # Also check legacy path
        legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"

        checkpoint_path = None
        if path.exists():
            checkpoint_path = path
        elif legacy_path.exists():
            checkpoint_path = legacy_path

        if not checkpoint_path:
            return None

        try:
            # Load just the metadata (weights_only would fail, but we catch it)
            checkpoint = torch.load(checkpoint_path, map_location="cpu", weights_only=False)

            arch_dict = checkpoint.get("architecture")
            if arch_dict:
                return ModelArchitecture.from_dict(arch_dict)
        except Exception as e:
            logger.debug(f"Could not peek checkpoint architecture: {e}")

        return None

def _load_embedding_with_vocab_expansion(self, embedding: nn.Embedding, state_dict: dict, name: str):
|
|
3597
|
+
"""
|
|
3598
|
+
Load embedding weights, handling vocab size expansion gracefully.
|
|
3599
|
+
|
|
3600
|
+
When vocabulary grows (tokenizer learns new merges), the checkpoint has fewer
|
|
3601
|
+
tokens than the current model. This method:
|
|
3602
|
+
1. Loads weights for existing tokens (preserves all training)
|
|
3603
|
+
2. Keeps randomly initialized weights for new tokens
|
|
3604
|
+
"""
|
|
3605
|
+
checkpoint_weight = state_dict.get("weight")
|
|
3606
|
+
if checkpoint_weight is None:
|
|
3607
|
+
logger.warning(f"[CHECKPOINT] No weight found in {name} state_dict")
|
|
3608
|
+
return
|
|
3609
|
+
|
|
3610
|
+
checkpoint_vocab_size = checkpoint_weight.shape[0]
|
|
3611
|
+
current_vocab_size = embedding.weight.shape[0]
|
|
3612
|
+
|
|
3613
|
+
if checkpoint_vocab_size == current_vocab_size:
|
|
3614
|
+
# Same size - normal load
|
|
3615
|
+
embedding.load_state_dict(state_dict)
|
|
3616
|
+
logger.info(f"[CHECKPOINT] Loaded {name}: {current_vocab_size} tokens")
|
|
3617
|
+
elif checkpoint_vocab_size < current_vocab_size:
|
|
3618
|
+
# Vocab expanded - partial load (PRESERVE TRAINING!)
|
|
3619
|
+
with torch.no_grad():
|
|
3620
|
+
embedding.weight[:checkpoint_vocab_size] = checkpoint_weight
|
|
3621
|
+
logger.info(f"[CHECKPOINT] Loaded {name} with vocab expansion: "
|
|
3622
|
+
f"{checkpoint_vocab_size} → {current_vocab_size} tokens "
|
|
3623
|
+
f"(preserved {checkpoint_vocab_size} trained embeddings)")
|
|
3624
|
+
else:
|
|
3625
|
+
# Vocab shrunk (unusual) - load what fits
|
|
3626
|
+
with torch.no_grad():
|
|
3627
|
+
embedding.weight[:] = checkpoint_weight[:current_vocab_size]
|
|
3628
|
+
logger.warning(f"[CHECKPOINT] Loaded {name} with vocab truncation: "
|
|
3629
|
+
f"{checkpoint_vocab_size} → {current_vocab_size} tokens")
|
|
3630
|
+
|
|
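The partial-load path above is easy to verify in isolation. A minimal standalone sketch (illustrative only, not part of the package diff) copies a 4-token "checkpoint" embedding into a 6-token embedding and checks that the trained rows survive while the new rows keep their random initialization:

    import torch
    import torch.nn as nn

    old_emb = nn.Embedding(4, 8)   # stands in for the checkpointed embedding
    new_emb = nn.Embedding(6, 8)   # current model after the tokenizer grew
    ckpt_weight = old_emb.state_dict()["weight"]

    with torch.no_grad():
        new_emb.weight[:ckpt_weight.shape[0]] = ckpt_weight   # same copy as the method above

    assert torch.equal(new_emb.weight[:4], old_emb.weight)    # trained rows preserved
    # rows 4-5 remain randomly initialized for the newly added tokens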
+    def _load_lm_head_with_vocab_expansion(self, lm_head: nn.Linear, state_dict: dict, name: str):
+        """
+        Load LM head weights, handling vocab size expansion gracefully.
+
+        Similar to embedding expansion - preserves trained weights for existing tokens.
+        """
+        checkpoint_weight = state_dict.get("weight")
+        checkpoint_bias = state_dict.get("bias")
+
+        if checkpoint_weight is None:
+            logger.warning(f"[CHECKPOINT] No weight found in {name} state_dict")
+            return
+
+        checkpoint_vocab_size = checkpoint_weight.shape[0]
+        current_vocab_size = lm_head.weight.shape[0]
+
+        if checkpoint_vocab_size == current_vocab_size:
+            # Same size - normal load
+            lm_head.load_state_dict(state_dict)
+            logger.info(f"[CHECKPOINT] Loaded {name}: {current_vocab_size} outputs")
+        elif checkpoint_vocab_size < current_vocab_size:
+            # Vocab expanded - partial load (PRESERVE TRAINING!)
+            with torch.no_grad():
+                lm_head.weight[:checkpoint_vocab_size] = checkpoint_weight
+                if checkpoint_bias is not None and lm_head.bias is not None:
+                    lm_head.bias[:checkpoint_vocab_size] = checkpoint_bias
+            logger.info(f"[CHECKPOINT] Loaded {name} with vocab expansion: "
+                        f"{checkpoint_vocab_size} → {current_vocab_size} outputs "
+                        f"(preserved {checkpoint_vocab_size} trained weights)")
+        else:
+            # Vocab shrunk (unusual) - load what fits
+            with torch.no_grad():
+                lm_head.weight[:] = checkpoint_weight[:current_vocab_size]
+                if checkpoint_bias is not None and lm_head.bias is not None:
+                    lm_head.bias[:] = checkpoint_bias[:current_vocab_size]
+            logger.warning(f"[CHECKPOINT] Loaded {name} with vocab truncation: "
+                           f"{checkpoint_vocab_size} → {current_vocab_size} outputs")
+
+    def _load_checkpoint(self):
+        """Load checkpoint from disk if it exists (resume training)."""
+        # Use wallet_id for stable checkpoint path (survives node_id changes)
+        path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
+
+        # Also check legacy path (node_id-based) for migration
+        legacy_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.node_id[:16]}.pt"
+        if not path.exists() and legacy_path.exists():
+            logger.info(f"Migrating checkpoint from legacy path: {legacy_path.name} -> {path.name}")
+            legacy_path.rename(path)
+
+        if not path.exists():
+            logger.info(f"No checkpoint found at {path.name}, starting fresh")
+            return False
+
+        logger.info(f"Loading checkpoint from: {path.name}")
+        try:
+            checkpoint = torch.load(path, map_location=self.device, weights_only=False)
+
+            # ARCHITECTURE COMPATIBILITY CHECK
+            saved_arch_dict = checkpoint.get("architecture")
+            if saved_arch_dict:
+                saved_arch = ModelArchitecture.from_dict(saved_arch_dict)
+                current_arch = self.model.architecture
+
+                # Check if architecture changed (includes num_heads for head_dim compatibility)
+                if (saved_arch.hidden_dim != current_arch.hidden_dim or
+                        saved_arch.intermediate_dim != current_arch.intermediate_dim or
+                        saved_arch.num_heads != current_arch.num_heads or
+                        saved_arch.num_kv_heads != current_arch.num_kv_heads):
+                    logger.warning(f"Architecture mismatch! Checkpoint is incompatible.")
+                    logger.warning(f" Saved: {saved_arch.num_layers}L × {saved_arch.hidden_dim}H, "
+                                   f"heads={saved_arch.num_heads}/{saved_arch.num_kv_heads}")
+                    logger.warning(f" Current: {current_arch.num_layers}L × {current_arch.hidden_dim}H, "
+                                   f"heads={current_arch.num_heads}/{current_arch.num_kv_heads}")
+                    logger.warning(f" Starting fresh (architecture was upgraded)")
+                    # Delete incompatible checkpoint
+                    try:
+                        path.unlink()
+                        logger.info(f"Deleted incompatible checkpoint: {path}")
+                    except Exception:
+                        pass
+                    return False
+            else:
+                logger.warning("Legacy checkpoint without architecture info - starting fresh")
+                # Delete legacy checkpoint
+                try:
+                    path.unlink()
+                    logger.info(f"Deleted legacy checkpoint: {path}")
+                except Exception:
+                    pass
+                return False
+
+            # Check layer assignment compatibility
+            saved_layers = set(checkpoint.get("layer_ids", []))
+            current_layers = set(self.my_layer_ids)
+
+            if saved_layers != current_layers:
+                # Layers changed - try to load what we can
+                common_layers = saved_layers.intersection(current_layers)
+                if common_layers:
+                    logger.warning(f"Layer assignment changed: saved={len(saved_layers)}, current={len(current_layers)}, common={len(common_layers)}")
+                    logger.info(f"Will load {len(common_layers)} common layers from checkpoint")
+                else:
+                    logger.warning(f"No common layers between checkpoint and current assignment, starting fresh")
+                    return False
+
+            # Load layer weights
+            for layer_id, state_dict in checkpoint.get("layers", {}).items():
+                layer_id = int(layer_id)
+                if layer_id in self.model.my_layers:
+                    self.model.my_layers[layer_id].load_state_dict(state_dict)
+
+            # Load embedding if present (handle vocab size changes gracefully)
+            if self.model.embedding and "embedding" in checkpoint:
+                self._load_embedding_with_vocab_expansion(
+                    self.model.embedding,
+                    checkpoint["embedding"],
+                    "embedding"
+                )
+
+            # Load LM head if present (handle vocab size changes gracefully)
+            if self.model.lm_head and "lm_head" in checkpoint:
+                self._load_lm_head_with_vocab_expansion(
+                    self.model.lm_head,
+                    checkpoint["lm_head"],
+                    "lm_head"
+                )
+
+            # Load final norm if present
+            if self.model.final_norm and "final_norm" in checkpoint:
+                self.model.final_norm.load_state_dict(checkpoint["final_norm"])
+
+            # Restore training state
+            self.total_training_rounds = checkpoint.get("total_training_rounds", 0)
+
+            # Store optimizer state for later loading (after optimizer is created)
+            if "optimizer" in checkpoint:
+                self._pending_optimizer_state = checkpoint["optimizer"]
+
+            # Store DiLoCo state for later loading (after swarm is created)
+            if "diloco" in checkpoint:
+                self._pending_diloco_state = checkpoint["diloco"]
+                logger.info("[NODE] DiLoCo state found in checkpoint, will restore after swarm init")
+
+            # Count how many layers were actually loaded
+            loaded_layer_count = sum(1 for lid in checkpoint.get("layers", {}).keys() if int(lid) in self.model.my_layers)
+            logger.info(f"Checkpoint loaded: {self.total_training_rounds} training rounds, "
+                        f"{loaded_layer_count}/{len(current_layers)} layers from {path}")
+            return True
+
+        except Exception as e:
+            logger.warning(f"Failed to load checkpoint: {e}, starting fresh")
+            return False
+
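For debugging, the on-disk layout consumed by _load_checkpoint can be inspected directly. A minimal sketch, assuming you substitute the node's actual CHECKPOINT_DIR and wallet id (the path below is a placeholder, not a real default):

    import torch
    from pathlib import Path

    # Placeholder path - point this at the node's CHECKPOINT_DIR and wallet_id
    ckpt_path = Path("checkpoints") / "dynamic_node_<wallet_id>.pt"

    ckpt = torch.load(ckpt_path, map_location="cpu", weights_only=False)
    print(sorted(ckpt.keys()))   # e.g. ['architecture', 'layer_ids', 'layers', 'total_training_rounds', ...]
    print(ckpt.get("total_training_rounds"), ckpt.get("layer_ids"))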
+    def _restore_pending_state(self):
+        """
+        Restore optimizer and DiLoCo state after they are initialized.
+
+        Called after swarm/optimizer are set up to restore checkpoint state.
+        """
+        # Restore optimizer state
+        if hasattr(self, '_pending_optimizer_state') and self._pending_optimizer_state:
+            if hasattr(self, 'optimizer') and self.optimizer:
+                try:
+                    self.optimizer.load_state_dict(self._pending_optimizer_state)
+                    logger.info("[NODE] Restored optimizer state from checkpoint")
+                except Exception as e:
+                    logger.warning(f"[NODE] Could not restore optimizer state: {e}")
+            del self._pending_optimizer_state
+
+        # Restore DiLoCo state
+        if hasattr(self, '_pending_diloco_state') and self._pending_diloco_state:
+            if hasattr(self, 'swarm') and self.swarm:
+                diloco = getattr(self.swarm, 'diloco_trainer', None)
+                if diloco and hasattr(diloco, 'load_state_dict'):
+                    try:
+                        diloco.load_state_dict(self._pending_diloco_state)
+                        logger.info(f"[NODE] Restored DiLoCo state (inner_step={diloco.stats.inner_step_count})")
+                    except Exception as e:
+                        logger.warning(f"[NODE] Could not restore DiLoCo state: {e}")
+            del self._pending_diloco_state
+
+    # Class-level save lock to prevent concurrent checkpoint saves
+    _checkpoint_save_lock = threading.Lock()
+    _checkpoint_save_in_progress = False
+
+    def _save_checkpoint(self, async_save: bool = True):
+        """
+        Smart checkpoint saving with STREAMING ASYNC for memory-constrained systems.
+
+        THREE MODES:
+        1. BULK ASYNC (>32GB free OR >2.5x checkpoint): Clone all, save in thread
+        2. STREAMING ASYNC (>500MB free): Clone one layer at a time, save incrementally
+        3. SYNC (<500MB free): Blocking save (last resort)
+
+        Streaming async blocks training only during the brief snapshot (~1-2s for 110 layers),
+        then saves to disk in background (~10-60s) while training continues.
+
+        Thread-safe: concurrent saves are serialized via lock.
+        """
+        if not self.model:
+            return
+
+        # Prevent concurrent saves (async save might still be in progress)
+        if not DynamicNeuroNode._checkpoint_save_lock.acquire(blocking=False):
+            logger.debug("[NODE] Checkpoint save skipped - another save in progress")
+            return
+
+        # Use wallet_id for stable checkpoint path
+        path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt"
+        temp_path = self.CHECKPOINT_DIR / f"dynamic_node_{self.wallet_id}.pt.tmp"
+
+        # 1. Assess memory situation (respect configured --memory limit!)
+        try:
+            total_params = sum(p.numel() for p in self.model.parameters())
+            checkpoint_size_mb = (total_params * 4) / (1024 * 1024)
+
+            # Get ACTUAL available memory (respecting configured limit)
+            vm = psutil.virtual_memory()
+            system_available_mb = vm.available / (1024 * 1024)
+
+            # Check current process memory usage
+            process = psutil.Process()
+            process_used_mb = process.memory_info().rss / (1024 * 1024)
+
+            # If user set --memory limit, respect it
+            # Available = min(system_available, configured_limit - current_usage)
+            configured_limit = getattr(self, 'available_memory_mb', None)
+            if configured_limit:
+                # How much headroom do we have within our configured limit?
+                headroom_mb = max(0, configured_limit - process_used_mb)
+                # Use the more conservative of system available or our headroom
+                available_mb = min(system_available_mb, headroom_mb)
+                logger.debug(f"[NODE] Memory check: process={process_used_mb:.0f}MB, "
+                             f"limit={configured_limit:.0f}MB, headroom={headroom_mb:.0f}MB, "
+                             f"system_free={system_available_mb:.0f}MB, using={available_mb:.0f}MB")
+            else:
+                available_mb = system_available_mb
+
+            # Determine save mode based on available headroom
+            # Bulk async needs 2.5x checkpoint size to clone everything
+            can_bulk_async = (available_mb > (checkpoint_size_mb * 2.5)) or (available_mb > 32000)
+            # Streaming async just needs enough for 1 layer (~50-100MB typically)
+            can_stream_async = available_mb > 500
+
+        except Exception as e:
+            logger.warning(f"[NODE] Could not assess memory: {e}. Using streaming async.")
+            can_bulk_async = False
+            can_stream_async = True
+            checkpoint_size_mb = 0
+            available_mb = 0
+
+        try:
+            # ============ ATOMIC SNAPSHOT PHASE ============
+            # ALL state must be captured together to ensure consistency.
+            # DiLoCo state + model weights must be from the same "moment in time".
+
+            # Helper to deep-clone DiLoCo state (ALL tensors to CPU for async safety)
+            def _clone_diloco_state():
+                if not hasattr(self, 'swarm') or not self.swarm:
+                    return None
+                diloco = getattr(self.swarm, 'diloco_trainer', None)
+                if not diloco or not hasattr(diloco, 'state_dict'):
+                    return None
+                try:
+                    state = diloco.state_dict()
+
+                    # Deep clone optimizer state (handles both PyTorch and custom formats)
+                    def _clone_optimizer_state(opt_state):
+                        if opt_state is None:
+                            return None
+                        cloned = {}
+                        for key, value in opt_state.items():
+                            if isinstance(value, torch.Tensor):
+                                # Direct tensor (e.g., in custom optimizers)
+                                cloned[key] = value.detach().clone().cpu()
+                            elif isinstance(value, dict):
+                                # Nested dict (e.g., 'state' or 'velocity' dicts)
+                                cloned[key] = {}
+                                for k, v in value.items():
+                                    if isinstance(v, torch.Tensor):
+                                        cloned[key][k] = v.detach().clone().cpu()
+                                    elif isinstance(v, dict):
+                                        # PyTorch optimizer 'state' has param_idx -> {key: tensor}
+                                        cloned[key][k] = {}
+                                        for kk, vv in v.items():
+                                            if isinstance(vv, torch.Tensor):
+                                                cloned[key][k][kk] = vv.detach().clone().cpu()
+                                            else:
+                                                cloned[key][k][kk] = vv
+                                    else:
+                                        cloned[key][k] = v
+                            elif isinstance(value, list):
+                                # List (e.g., param_groups) - shallow copy is fine
+                                cloned[key] = list(value)
+                            else:
+                                # Scalar values (lr, momentum, etc.)
+                                cloned[key] = value
+                        return cloned
+
+                    # Deep clone all tensors to CPU for async safety
+                    cloned = {
+                        'config': dict(state.get('config', {})),
+                        'inner_optimizer': _clone_optimizer_state(state.get('inner_optimizer')),
+                        'outer_optimizer': _clone_optimizer_state(state.get('outer_optimizer')),
+                        'initial_weights': {
+                            k: v.detach().clone().cpu()
+                            for k, v in state.get('initial_weights', {}).items()
+                        },
+                        'stats': dict(state.get('stats', {})),
+                        'phase': state.get('phase', 'idle'),
+                    }
+                    return cloned
+                except Exception as e:
+                    logger.warning(f"[NODE] Could not snapshot DiLoCo state: {e}")
+                    return None
+
+            # ============ MODE 1: BULK ASYNC (plenty of memory) ============
+            if async_save and can_bulk_async:
+                logger.debug(f"[NODE] Checkpoint: BULK ASYNC (Free: {available_mb:.0f}MB)")
+
+                # Capture everything atomically
+                checkpoint = {
+                    "node_id": self.node_id,
+                    "layer_ids": list(self.my_layer_ids),
+                    "architecture": self.model.architecture.to_dict(),
+                    "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
+                    "has_embedding": self.model.has_embedding,
+                    "has_lm_head": self.model.has_lm_head,
+                    "total_training_rounds": self.total_training_rounds,
+                    "current_loss": self.current_loss,
+                    "timestamp": time.time(),
+                    "layers": {
+                        layer_id: {k: v.clone().cpu() for k, v in layer.state_dict().items()}
+                        for layer_id, layer in self.model.my_layers.items()
+                    },
+                }
+                if self.model.embedding:
+                    checkpoint["embedding"] = {k: v.clone().cpu() for k, v in self.model.embedding.state_dict().items()}
+                if self.model.lm_head:
+                    checkpoint["lm_head"] = {k: v.clone().cpu() for k, v in self.model.lm_head.state_dict().items()}
+                if self.model.final_norm:
+                    checkpoint["final_norm"] = {k: v.clone().cpu() for k, v in self.model.final_norm.state_dict().items()}
+
+                # DiLoCo state captured AFTER model weights (both in same atomic snapshot)
+                diloco_state = _clone_diloco_state()
+                if diloco_state:
+                    checkpoint["diloco"] = diloco_state
+
+                def _do_bulk_save():
+                    try:
+                        torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
+                        import shutil
+                        shutil.move(str(temp_path), str(path))
+                        logger.info(f"[NODE] Checkpoint saved ({len(self.my_layer_ids)} layers)")
+                    except Exception as e:
+                        logger.error(f"[NODE] Checkpoint save failed: {e}")
+                        if temp_path.exists(): temp_path.unlink()
+                    finally:
+                        DynamicNeuroNode._checkpoint_save_lock.release()
+
+                # Use daemon=False so checkpoint completes even during shutdown
+                threading.Thread(target=_do_bulk_save, daemon=False).start()
+                return  # Lock will be released by background thread
+
+            # ============ MODE 2: STREAMING ASYNC (memory-efficient) ============
+            if async_save and can_stream_async:
+                logger.info(f"[NODE] Checkpoint: STREAMING ASYNC (Free: {available_mb:.0f}MB, cloning {len(self.model.my_layers)} layers)")
+
+                # SNAPSHOT PHASE: Clone one layer at a time into a list
+                # This brief pause (~1-2s) ensures consistency without needing full clone memory
+                snapshot_start = time.time()
+
+                # Capture metadata first (lightweight)
+                checkpoint_meta = {
+                    "node_id": self.node_id,
+                    "layer_ids": list(self.my_layer_ids),
+                    "architecture": self.model.architecture.to_dict(),
+                    "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
+                    "has_embedding": self.model.has_embedding,
+                    "has_lm_head": self.model.has_lm_head,
+                    "total_training_rounds": self.total_training_rounds,
+                    "current_loss": self.current_loss,
+                    "timestamp": time.time(),
+                }
+
+                # Clone layers one at a time (memory efficient)
+                layer_snapshots = []
+                for layer_id, layer in self.model.my_layers.items():
+                    layer_state = {k: v.detach().clone().cpu() for k, v in layer.state_dict().items()}
+                    layer_snapshots.append((layer_id, layer_state))
+
+                # Clone special modules
+                embedding_snapshot = None
+                lm_head_snapshot = None
+                final_norm_snapshot = None
+
+                if self.model.embedding:
+                    embedding_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.embedding.state_dict().items()}
+                if self.model.lm_head:
+                    lm_head_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.lm_head.state_dict().items()}
+                if self.model.final_norm:
+                    final_norm_snapshot = {k: v.detach().clone().cpu() for k, v in self.model.final_norm.state_dict().items()}
+
+                # DiLoCo state - captured in SAME snapshot window as model weights
+                diloco_snapshot = _clone_diloco_state()
+
+                snapshot_time = time.time() - snapshot_start
+                logger.debug(f"[NODE] Snapshot complete in {snapshot_time:.1f}s, starting async save")
+
+                # ASYNC SAVE PHASE: Write to disk in background thread
+                # All data is now cloned and owned by this closure - safe for async
+                def _do_streaming_save():
+                    try:
+                        checkpoint = dict(checkpoint_meta)
+                        checkpoint["layers"] = {lid: lstate for lid, lstate in layer_snapshots}
+                        if embedding_snapshot:
+                            checkpoint["embedding"] = embedding_snapshot
+                        if lm_head_snapshot:
+                            checkpoint["lm_head"] = lm_head_snapshot
+                        if final_norm_snapshot:
+                            checkpoint["final_norm"] = final_norm_snapshot
+                        if diloco_snapshot:
+                            checkpoint["diloco"] = diloco_snapshot
+
+                        torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
+                        import shutil
+                        shutil.move(str(temp_path), str(path))
+                        logger.info(f"[NODE] Checkpoint saved ({len(layer_snapshots)} layers)")
+                    except Exception as e:
+                        logger.error(f"[NODE] Checkpoint save failed: {e}")
+                        if temp_path.exists(): temp_path.unlink()
+                    finally:
+                        DynamicNeuroNode._checkpoint_save_lock.release()
+
+                # Use daemon=False so checkpoint completes even during shutdown
+                threading.Thread(target=_do_streaming_save, daemon=False).start()
+                return  # Lock will be released by background thread
+
+            # ============ MODE 3: SYNC (last resort, very low memory) ============
+            logger.warning(f"[NODE] Checkpoint: SYNC mode (Free: {available_mb:.0f}MB < 500MB minimum)")
+
+            checkpoint = {
+                "node_id": self.node_id,
+                "layer_ids": list(self.my_layer_ids),
+                "architecture": self.model.architecture.to_dict(),
+                "architecture_version": self.layer_pool.architecture_version if self.layer_pool else 1,
+                "has_embedding": self.model.has_embedding,
+                "has_lm_head": self.model.has_lm_head,
+                "total_training_rounds": self.total_training_rounds,
+                "current_loss": self.current_loss,
+                "timestamp": time.time(),
+                "layers": {
+                    layer_id: layer.state_dict()
+                    for layer_id, layer in self.model.my_layers.items()
+                },
+            }
+            if self.model.embedding:
+                checkpoint["embedding"] = self.model.embedding.state_dict()
+            if self.model.lm_head:
+                checkpoint["lm_head"] = self.model.lm_head.state_dict()
+            if self.model.final_norm:
+                checkpoint["final_norm"] = self.model.final_norm.state_dict()
+
+            # DiLoCo state (no need to clone for sync - we block anyway)
+            diloco_state = _clone_diloco_state()
+            if diloco_state:
+                checkpoint["diloco"] = diloco_state
+
+            torch.save(checkpoint, temp_path, _use_new_zipfile_serialization=False)
+            import shutil
+            shutil.move(str(temp_path), str(path))
+            logger.info(f"[NODE] Checkpoint saved ({len(self.my_layer_ids)} layers)")
+
+            # Sync mode completed successfully, release lock
+            DynamicNeuroNode._checkpoint_save_lock.release()
+
+        except Exception as e:
+            logger.error(f"[NODE] Checkpoint preparation failed: {type(e).__name__}: {e}")
+            try:
+                if temp_path.exists(): temp_path.unlink()
+            except:
+                pass
+            # Release lock on exception
+            DynamicNeuroNode._checkpoint_save_lock.release()
+
+
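The mode selection in _save_checkpoint reduces to two thresholds. The standalone sketch below (illustrative only; the function name and return strings are made up, not part of the package) restates that decision rule with a few example evaluations:

    def pick_save_mode(available_mb: float, checkpoint_size_mb: float, async_save: bool = True) -> str:
        can_bulk_async = (available_mb > checkpoint_size_mb * 2.5) or (available_mb > 32000)
        can_stream_async = available_mb > 500
        if async_save and can_bulk_async:
            return "bulk_async"       # clone everything, save in a background thread
        if async_save and can_stream_async:
            return "streaming_async"  # clone layer-by-layer, then save in background
        return "sync"                 # blocking save, last resort

    assert pick_save_mode(available_mb=40000, checkpoint_size_mb=8000) == "bulk_async"
    assert pick_save_mode(available_mb=2000, checkpoint_size_mb=8000) == "streaming_async"
    assert pick_save_mode(available_mb=300, checkpoint_size_mb=8000) == "sync"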
+def create_dynamic_node(
+    node_token: str,
+    port: int = 8000,
+    tracker_url: str = "https://neuroshard.com/api/tracker",
+    available_memory_mb: Optional[float] = None,
+    enable_training: bool = True,
+    max_storage_mb: float = 100.0,
+    max_cpu_threads: Optional[int] = None,
+    device: str = "auto",
+    p2p_manager: Optional[Any] = None,  # NEW: Pass P2P for DHT discovery during layer assignment
+) -> DynamicNeuroNode:
+    """
+    Create and start a dynamic node.
+
+    MULTI-NODE SUPPORT:
+    If the same token is used on multiple machines or ports, each gets a unique
+    node_id (based on machine + port) while sharing the same wallet_id (based on token).
+
+    This means:
+    - Each physical node has a unique network identity
+    - Earnings accumulate to the same NEURO wallet
+    - No conflicts in DHT/layer assignments
+
+    FULLY DECENTRALIZED:
+    If p2p_manager is provided, DHT is used for network discovery during layer
+    assignment. No tracker fallbacks - pure P2P!
+    """
+    from neuroshard.utils.hardware import get_instance_id
+
+    # Generate instance-specific node_id
+    instance_id = get_instance_id(port)
+
+    # Combine token with instance for unique network identity
+    # wallet_id (from token alone) is used for NEURO earnings
+    # node_id (from token + instance) is used for network identity
+    combined = f"{node_token}:{instance_id}"
+    node_id = str(int(hashlib.sha256(combined.encode()).hexdigest(), 16))
+
+    # Log multi-node info
+    wallet_id = hashlib.sha256(node_token.encode()).hexdigest()[:16]
+    logger.info(f"Instance ID: {instance_id} (machine+port)")
+    logger.info(f"Wallet ID: {wallet_id}... (for NEURO earnings)")
+    logger.info(f"Node ID: {node_id[:16]}... (unique network identity)")
+
+    node = DynamicNeuroNode(
+        node_id=node_id,
+        port=port,
+        tracker_url=tracker_url,
+        node_token=node_token,
+        available_memory_mb=available_memory_mb,
+        enable_training=enable_training,
+        max_storage_mb=max_storage_mb,
+        max_cpu_threads=max_cpu_threads,
+        device=device,
+    )
+
+    # Store instance info for debugging
+    node.instance_id = instance_id
+    node.wallet_id = wallet_id
+
+    # CRITICAL: Connect P2P BEFORE start() so DHT is available for layer discovery!
+    # This enables fully decentralized network discovery without tracker fallbacks.
+    if p2p_manager:
+        node.p2p_manager = p2p_manager
+        logger.info("P2P connected BEFORE start - DHT available for network discovery")
+
+    node.start()
+
+    return node
+
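A minimal usage sketch of create_dynamic_node, assuming this module's import path (neuroshard/core/model/dynamic.py) and a placeholder token; every keyword shown comes from the signature above:

    from neuroshard.core.model.dynamic import create_dynamic_node

    node = create_dynamic_node(
        node_token="my-node-token",    # placeholder; the same token on two machines shares one wallet_id
        port=8001,                     # each machine/port combination gets its own node_id
        available_memory_mb=8192,
        enable_training=True,
    )
    # The factory has already called node.start(), so the returned object is
    # serving its assigned layers and (if enabled) training.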