nexaroa 0.0.111__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/training/distributed.py
@@ -0,0 +1,1602 @@
"""
Distributed Training System for NeuroLLM

Implements decentralized training where:
1. Nodes contribute compute for forward/backward passes
2. Gradients are aggregated via gossip protocol
3. Training rewards are distributed in NEURO tokens
4. Model checkpoints are shared across the network

Key Components:
- GradientAggregator: Collects and averages gradients from peers
- TrainingCoordinator: Orchestrates distributed training
- DataContributor: Handles federated dataset management
- RewardCalculator: Computes NEURO rewards for contributions

Training Flow:
1. Coordinator broadcasts current model state hash
2. Nodes with matching state participate in training round
3. Each node processes local data batch
4. Gradients are compressed and gossiped
5. Aggregated gradients are applied
6. New checkpoint is created and distributed
7. NEURO rewards are calculated and distributed
"""

import torch
import torch.nn as nn
import threading
import time
import hashlib
from concurrent.futures import ThreadPoolExecutor
import logging
import json
import io
import zlib
import base64
import os
import requests
from typing import Dict, List, Optional, Tuple, Any, Callable
from dataclasses import dataclass, field
from enum import Enum
from collections import defaultdict

# Import economics constants for consistency
from neuroshard.core.economics.constants import (
    TRAINING_REWARD_PER_BATCH,
    DATA_REWARD_PER_SAMPLE
)

logger = logging.getLogger(__name__)

class TrainingState(Enum):
    """State of the training coordinator."""
    IDLE = "idle"
    COLLECTING = "collecting"  # Collecting gradients from peers
    AGGREGATING = "aggregating"  # Aggregating gradients
    APPLYING = "applying"  # Applying updates
    CHECKPOINTING = "checkpointing"  # Creating checkpoint


@dataclass
class GradientContribution:
    """A gradient contribution from a node."""
    node_id: str
    round_id: int
    layer_gradients: Dict[str, bytes]  # layer_name -> compressed gradient
    batch_size: int
    loss: float
    timestamp: float
    signature: str  # Proof of work


@dataclass
class TrainingRound:
    """A single training round."""
    round_id: int
    started_at: float
    model_hash: str

    # Contributions
    contributions: Dict[str, GradientContribution] = field(default_factory=dict)
    min_contributions: int = 3
    max_contributions: int = 100

    # Results
    aggregated_gradients: Optional[Dict[str, torch.Tensor]] = None
    total_batch_size: int = 0
    avg_loss: float = 0.0

    # State
    completed: bool = False
    applied: bool = False


@dataclass
class TrainingReward:
    """Reward for training contribution."""
    node_id: str
    round_id: int
    compute_reward: float  # For compute contribution
    data_reward: float  # For data contribution
    quality_bonus: float  # For high-quality gradients
    total_neuro: float


class GradientCompressor:
    """
    Compresses gradients for efficient network transmission.

    Uses a combination of:
    1. Top-K sparsification (keep only largest gradients)
    2. Quantization (reduce precision)
    3. zlib compression
    """

    def __init__(self, top_k_ratio: float = 0.1, bits: int = 8):
        self.top_k_ratio = top_k_ratio
        self.bits = bits

    def compress(self, gradient: torch.Tensor) -> bytes:
        """Compress a gradient tensor."""
        # CRITICAL: Move to CPU first for MPS/CUDA compatibility
        gradient = gradient.detach().cpu()

        # Flatten
        flat = gradient.flatten()

        # Top-K sparsification
        k = max(1, int(len(flat) * self.top_k_ratio))
        values, indices = torch.topk(flat.abs(), k)

        # Get actual values (with signs)
        sparse_values = flat[indices]

        # Quantize to specified bits
        max_val = sparse_values.abs().max()
        if max_val > 0:
            scale = (2 ** (self.bits - 1) - 1) / max_val
            quantized = (sparse_values * scale).round().to(torch.int8)
        else:
            quantized = torch.zeros(k, dtype=torch.int8)
            scale = 1.0

        # Pack into bytes (tensors already on CPU)
        data = {
            "shape": list(gradient.shape),
            "k": k,
            "indices": base64.b64encode(indices.numpy().tobytes()).decode('ascii'),
            "values": base64.b64encode(quantized.numpy().tobytes()).decode('ascii'),
            "scale": float(scale),
            "dtype": str(gradient.dtype),
        }

        # Serialize and compress
        json_data = json.dumps(data).encode()
        return zlib.compress(json_data)

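As a quick sanity check on the quantization step above: with bits=8 the largest surviving magnitude maps onto the int8 maximum of 127, and every other top-k value is stored on a grid of max_val/127. A standalone sketch of just that arithmetic (the tensor values are made up for illustration):

import torch

bits = 8
sparse_values = torch.tensor([0.5, -0.25, 0.1])  # illustrative top-k values
max_val = sparse_values.abs().max()               # 0.5
scale = (2 ** (bits - 1) - 1) / max_val           # 127 / 0.5 = 254.0
quantized = (sparse_values * scale).round().to(torch.int8)
print(quantized.tolist())                         # [127, -64, 25]
print((quantized.float() / scale).tolist())       # ~[0.5, -0.252, 0.098]
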
    def decompress(self, data: bytes, device: str = "cpu") -> torch.Tensor:
        """Decompress a gradient tensor."""
        # Decompress and deserialize
        json_data = zlib.decompress(data)
        packed = json.loads(json_data)

        # Unpack
        shape = packed["shape"]
        k = packed["k"]
        indices = torch.frombuffer(
            bytearray(base64.b64decode(packed["indices"])),
            dtype=torch.int64
        ).clone().to(device)
        values = torch.frombuffer(
            bytearray(base64.b64decode(packed["values"])),
            dtype=torch.int8
        ).float().clone().to(device)
        scale = packed["scale"]

        # Dequantize
        values = values / scale

        # Reconstruct sparse tensor
        flat = torch.zeros(torch.prod(torch.tensor(shape)), device=device)
        flat[indices] = values

        return flat.view(*shape)

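A minimal round-trip sketch of the compressor (illustrative only; the shape and ratio are arbitrary). Because of top-k sparsification and int8 quantization the reconstruction is lossy: with the default settings only roughly the largest 10% of entries survive.

import torch
from neuroshard.core.training.distributed import GradientCompressor

compressor = GradientCompressor(top_k_ratio=0.1, bits=8)

grad = torch.randn(64, 32)              # stand-in for a layer gradient
blob = compressor.compress(grad)        # zlib-compressed JSON payload
restored = compressor.decompress(blob)  # sparse, dequantized tensor

print(len(blob), "bytes on the wire")
print(restored.shape)                   # torch.Size([64, 32])
print((restored != 0).float().mean())   # ~0.10: only the top-k entries survive
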
class GradientAggregator:
    """
    Aggregates gradients from multiple nodes.

    Supports:
    - Simple averaging
    - Weighted averaging (by batch size)
    - Robust aggregation (median, trimmed mean)
    """

    def __init__(self, method: str = "weighted_mean"):
        self.method = method
        self.compressor = GradientCompressor()

    def aggregate(
        self,
        contributions: List[GradientContribution],
        layer_names: List[str]
    ) -> Dict[str, torch.Tensor]:
        """
        Aggregate gradients from multiple contributions.

        Args:
            contributions: List of gradient contributions
            layer_names: Names of layers to aggregate

        Returns:
            Aggregated gradients per layer
        """
        if not contributions:
            return {}

        aggregated = {}
        total_batch_size = sum(c.batch_size for c in contributions)

        for layer_name in layer_names:
            # Collect gradients for this layer
            gradients = []
            weights = []

            for contrib in contributions:
                if layer_name in contrib.layer_gradients:
                    grad = self.compressor.decompress(contrib.layer_gradients[layer_name])
                    gradients.append(grad)
                    weights.append(contrib.batch_size)

            if not gradients:
                continue

            # Stack gradients
            stacked = torch.stack(gradients)

            if self.method == "mean":
                aggregated[layer_name] = stacked.mean(dim=0)

            elif self.method == "weighted_mean":
                weights_tensor = torch.tensor(weights, dtype=torch.float32)
                weights_tensor = weights_tensor / weights_tensor.sum()
                aggregated[layer_name] = (stacked * weights_tensor.view(-1, *([1] * (stacked.dim() - 1)))).sum(dim=0)

            elif self.method == "median":
                aggregated[layer_name] = stacked.median(dim=0)[0]

            elif self.method == "trimmed_mean":
                # Remove top and bottom 10%
                k = max(1, len(gradients) // 10)
                sorted_grads = stacked.sort(dim=0)[0]
                aggregated[layer_name] = sorted_grads[k:-k].mean(dim=0) if k < len(gradients) // 2 else stacked.mean(dim=0)

        return aggregated

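A small sketch of the batch-size-weighted averaging path with two synthetic contributions (node IDs, values and batch sizes are made up; top_k_ratio=1.0 keeps every element so the arithmetic is easy to follow):

import time
import torch
from neuroshard.core.training.distributed import (
    GradientAggregator, GradientCompressor, GradientContribution,
)

compressor = GradientCompressor(top_k_ratio=1.0)
aggregator = GradientAggregator(method="weighted_mean")

def make_contribution(node_id, value, batch_size):
    grad = torch.full((4,), value)
    return GradientContribution(
        node_id=node_id,
        round_id=1,
        layer_gradients={"layer.weight": compressor.compress(grad)},
        batch_size=batch_size,
        loss=1.0,
        timestamp=time.time(),
        signature="example",
    )

contribs = [make_contribution("node-a", 1.0, 8), make_contribution("node-b", 4.0, 2)]
result = aggregator.aggregate(contribs, ["layer.weight"])
print(result["layer.weight"])  # (8 * 1.0 + 2 * 4.0) / 10 = 1.6 per element
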
class TrainingCoordinator:
    """
    Coordinates distributed training across the network.

    Responsibilities:
    1. Initiate training rounds
    2. Collect gradient contributions
    3. Aggregate and apply updates
    4. Distribute rewards
    5. Manage checkpoints

    NOTE: This class is LEGACY and not currently used in production.
    The active reward path uses economics.py constants via ledger.py
    """

    # Configuration
    ROUND_DURATION_SECONDS = 60
    MIN_CONTRIBUTIONS = 3
    GRADIENT_CLIP_NORM = 1.0

    # Reward rates (using economics.py constants for consistency)
    # NOTE: LEGACY - These are kept for backwards compatibility but not actively used
    # Import at class level to match economics.py values
    from neuroshard.core.economics.constants import TRAINING_REWARD_PER_BATCH as COMPUTE_REWARD_PER_BATCH
    from neuroshard.core.economics.constants import DATA_REWARD_PER_SAMPLE
    QUALITY_BONUS_MULTIPLIER = 1.5

    def __init__(
        self,
        model: nn.Module,
        optimizer: torch.optim.Optimizer,
        node_id: str,
        ledger_manager=None,
        on_round_complete: Optional[Callable] = None
    ):
        self.model = model
        self.optimizer = optimizer
        self.node_id = node_id
        self.ledger = ledger_manager
        self.on_round_complete = on_round_complete

        # State
        self.state = TrainingState.IDLE
        self.current_round: Optional[TrainingRound] = None
        self.round_history: List[TrainingRound] = []
        self.global_step = 0

        # Components
        self.aggregator = GradientAggregator()
        self.compressor = GradientCompressor()

        # Threading
        self.lock = threading.Lock()
        self.running = False

        # Stats
        self.total_rounds = 0
        self.total_contributions = 0
        self.total_neuro_distributed = 0.0

        logger.info(f"TrainingCoordinator initialized for node {node_id}")

    def start(self):
        """Start the training coordinator."""
        self.running = True
        threading.Thread(target=self._training_loop, daemon=True).start()
        logger.info("Training coordinator started")

    def stop(self):
        """Stop the training coordinator."""
        self.running = False

    def _training_loop(self):
        """Main training loop."""
        while self.running:
            try:
                if self.state == TrainingState.IDLE:
                    # Start new round
                    self._start_round()

                elif self.state == TrainingState.COLLECTING:
                    # Check if round should complete
                    if self._should_complete_round():
                        self._complete_round()

                time.sleep(1)

            except Exception as e:
                logger.error(f"Training loop error: {e}")
                time.sleep(5)

    def _get_model_hash(self) -> str:
        """Get hash of current model state."""
        state_dict = self.model.state_dict()
        hasher = hashlib.sha256()

        for name, param in sorted(state_dict.items()):
            hasher.update(name.encode())
            hasher.update(param.cpu().numpy().tobytes()[:1000])  # Sample for speed

        return hasher.hexdigest()[:16]

    def _start_round(self):
        """Start a new training round."""
        with self.lock:
            self.total_rounds += 1

            self.current_round = TrainingRound(
                round_id=self.total_rounds,
                started_at=time.time(),
                model_hash=self._get_model_hash(),
                min_contributions=self.MIN_CONTRIBUTIONS,
            )

            self.state = TrainingState.COLLECTING

        logger.info(f"Started training round {self.total_rounds}")

    def _should_complete_round(self) -> bool:
        """Check if current round should complete."""
        if not self.current_round:
            return False

        # Time limit
        elapsed = time.time() - self.current_round.started_at
        if elapsed >= self.ROUND_DURATION_SECONDS:
            return True

        # Max contributions
        if len(self.current_round.contributions) >= self.current_round.max_contributions:
            return True

        return False

    def _complete_round(self):
        """Complete the current training round."""
        if not self.current_round:
            return

        with self.lock:
            round_data = self.current_round

            # Check minimum contributions
            if len(round_data.contributions) < round_data.min_contributions:
                logger.warning(f"Round {round_data.round_id} failed: insufficient contributions "
                               f"({len(round_data.contributions)}/{round_data.min_contributions})")
                self.state = TrainingState.IDLE
                self.current_round = None
                return

            self.state = TrainingState.AGGREGATING

        logger.info(f"Completing round {round_data.round_id} with {len(round_data.contributions)} contributions")

        # Aggregate gradients
        layer_names = [name for name, _ in self.model.named_parameters()]
        aggregated = self.aggregator.aggregate(
            list(round_data.contributions.values()),
            layer_names
        )

        round_data.aggregated_gradients = aggregated
        round_data.total_batch_size = sum(c.batch_size for c in round_data.contributions.values())
        round_data.avg_loss = sum(c.loss for c in round_data.contributions.values()) / len(round_data.contributions)

        # Apply gradients
        with self.lock:
            self.state = TrainingState.APPLYING

        self._apply_gradients(aggregated)
        round_data.applied = True

        # Calculate and distribute rewards
        rewards = self._calculate_rewards(round_data)
        self._distribute_rewards(rewards)

        # Checkpoint
        with self.lock:
            self.state = TrainingState.CHECKPOINTING

        self._create_checkpoint(round_data)

        # Complete
        round_data.completed = True
        self.round_history.append(round_data)

        # Keep only last 100 rounds
        if len(self.round_history) > 100:
            self.round_history = self.round_history[-100:]

        # Callback
        if self.on_round_complete:
            self.on_round_complete(round_data)

        # Reset
        with self.lock:
            self.current_round = None
            self.state = TrainingState.IDLE
            self.global_step += 1

        logger.info(f"Round {round_data.round_id} complete: loss={round_data.avg_loss:.4f}, "
                    f"batch_size={round_data.total_batch_size}")

    def _apply_gradients(self, gradients: Dict[str, torch.Tensor]):
        """Apply aggregated gradients to model."""
        self.optimizer.zero_grad()

        for name, param in self.model.named_parameters():
            if name in gradients:
                if param.grad is None:
                    param.grad = gradients[name].to(param.device)
                else:
                    param.grad.copy_(gradients[name])

        # Gradient clipping
        torch.nn.utils.clip_grad_norm_(self.model.parameters(), self.GRADIENT_CLIP_NORM)

        # Apply
        self.optimizer.step()

    def _calculate_rewards(self, round_data: TrainingRound) -> List[TrainingReward]:
        """Calculate NEURO rewards for contributions."""
        rewards = []

        # Calculate average loss for quality comparison
        avg_loss = round_data.avg_loss

        for node_id, contrib in round_data.contributions.items():
            # Base compute reward
            compute_reward = contrib.batch_size * self.COMPUTE_REWARD_PER_BATCH

            # Data contribution reward
            data_reward = contrib.batch_size * self.DATA_REWARD_PER_SAMPLE

            # Quality bonus (lower loss = better)
            quality_bonus = 0.0
            if contrib.loss < avg_loss:
                quality_bonus = compute_reward * (self.QUALITY_BONUS_MULTIPLIER - 1)

            total = compute_reward + data_reward + quality_bonus

            rewards.append(TrainingReward(
                node_id=node_id,
                round_id=round_data.round_id,
                compute_reward=compute_reward,
                data_reward=data_reward,
                quality_bonus=quality_bonus,
                total_neuro=total
            ))

            self.total_neuro_distributed += total

        return rewards

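The reward formula above is linear in batch size, with a 1.5x multiplier applied to the compute share when a node's loss beats the round average. A worked sketch with placeholder rates (the real TRAINING_REWARD_PER_BATCH and DATA_REWARD_PER_SAMPLE values live in neuroshard.core.economics.constants and are not shown in this diff):

# Placeholder rates for illustration only; the real constants may differ.
COMPUTE_REWARD_PER_BATCH = 0.001
DATA_REWARD_PER_SAMPLE = 0.0001
QUALITY_BONUS_MULTIPLIER = 1.5

batch_size = 32
round_avg_loss = 2.0
node_loss = 1.8  # below the round average -> earns the quality bonus

compute_reward = batch_size * COMPUTE_REWARD_PER_BATCH  # 0.032
data_reward = batch_size * DATA_REWARD_PER_SAMPLE       # 0.0032
quality_bonus = (compute_reward * (QUALITY_BONUS_MULTIPLIER - 1)
                 if node_loss < round_avg_loss else 0.0)  # 0.016
total_neuro = compute_reward + data_reward + quality_bonus
print(total_neuro)  # ~0.0512
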
    def _distribute_rewards(self, rewards: List[TrainingReward]):
        """Distribute NEURO rewards to contributors using PoNW proofs."""
        if not self.ledger:
            logger.debug("No ledger available for reward distribution")
            return

        for reward in rewards:
            try:
                from neuroshard.core.economics.ledger import PoNWProof, ProofType
                import time

                # Create a proper training PoNW proof
                proof = PoNWProof(
                    node_id=reward.node_id,
                    proof_type=ProofType.TRAINING.value,
                    timestamp=time.time(),
                    nonce=f"train_{reward.round_id}_{reward.node_id[:16]}",
                    training_batches=int(reward.compute_reward / self.COMPUTE_REWARD_PER_BATCH),
                    data_samples=int(reward.data_reward / self.DATA_REWARD_PER_SAMPLE),
                    signature=f"training_reward_{reward.round_id}_{reward.node_id}"
                )

                # Process through the ledger (handles deduplication, stats, etc.)
                success, amount, msg = self.ledger.process_proof(proof)

                if success:
                    logger.info(f"Reward: {reward.node_id[:8]}... earned {amount:.6f} NEURO "
                                f"(compute={reward.compute_reward:.6f}, data={reward.data_reward:.6f}, "
                                f"quality={reward.quality_bonus:.6f})")
                else:
                    logger.debug(f"Training reward not processed: {msg}")

            except Exception as e:
                logger.error(f"Failed to distribute reward to {reward.node_id}: {e}")

    def _create_checkpoint(self, round_data: TrainingRound):
        """Create a checkpoint after training round."""
        checkpoint_path = f"checkpoints/neuro_llm_round_{round_data.round_id}.pt"

        try:
            import os
            os.makedirs("checkpoints", exist_ok=True)

            torch.save({
                "model_state_dict": self.model.state_dict(),
                "optimizer_state_dict": self.optimizer.state_dict(),
                "round_id": round_data.round_id,
                "global_step": self.global_step,
                "model_hash": self._get_model_hash(),
                "avg_loss": round_data.avg_loss,
                "total_batch_size": round_data.total_batch_size,
                "timestamp": time.time(),
            }, checkpoint_path)

            logger.info(f"Checkpoint saved: {checkpoint_path}")

        except Exception as e:
            logger.error(f"Failed to save checkpoint: {e}")

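Since the checkpoint above is a plain torch.save dict, restoring it is symmetric. A minimal sketch, assuming a checkpoint file exists and that the model/optimizer match whatever produced it (the tiny nn.Linear and the round number in the path are stand-ins):

import torch
import torch.nn as nn

model = nn.Linear(8, 8)                                   # stand-in architecture
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)  # stand-in optimizer

ckpt = torch.load("checkpoints/neuro_llm_round_1.pt", map_location="cpu")
model.load_state_dict(ckpt["model_state_dict"])
optimizer.load_state_dict(ckpt["optimizer_state_dict"])
print(f"resuming at round {ckpt['round_id']}, step {ckpt['global_step']}, "
      f"avg_loss {ckpt['avg_loss']:.4f}")
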
    def submit_contribution(self, contribution: GradientContribution) -> bool:
        """
        Submit a gradient contribution for the current round.

        Called by peers when they have computed gradients.
        """
        with self.lock:
            if self.state != TrainingState.COLLECTING:
                return False

            if not self.current_round:
                return False

            # Verify model hash matches
            # In production, this would be more sophisticated

            # Add contribution
            self.current_round.contributions[contribution.node_id] = contribution
            self.total_contributions += 1

            logger.debug(f"Received contribution from {contribution.node_id[:8]}... "
                         f"(batch_size={contribution.batch_size}, loss={contribution.loss:.4f})")

            return True

    def compute_local_gradients(
        self,
        input_ids: torch.Tensor,
        labels: torch.Tensor
    ) -> GradientContribution:
        """
        Compute gradients on local data.

        Call this to participate in training.
        """
        self.model.train()

        # Forward pass
        outputs = self.model(input_ids=input_ids, labels=labels)
        loss = outputs["loss"]

        # Backward pass
        loss.backward()

        # Collect and compress gradients
        layer_gradients = {}
        for name, param in self.model.named_parameters():
            if param.grad is not None:
                layer_gradients[name] = self.compressor.compress(param.grad)

        # Clear gradients (they're saved in contribution)
        self.optimizer.zero_grad()

        # Create contribution
        contribution = GradientContribution(
            node_id=self.node_id,
            round_id=self.current_round.round_id if self.current_round else 0,
            layer_gradients=layer_gradients,
            batch_size=input_ids.shape[0],
            loss=loss.item(),
            timestamp=time.time(),
            signature=self._sign_contribution(layer_gradients)
        )

        return contribution

    def _sign_contribution(self, gradients: Dict[str, bytes]) -> str:
        """Sign a contribution for verification."""
        hasher = hashlib.sha256()
        hasher.update(self.node_id.encode())
        hasher.update(str(time.time()).encode())
        for name, data in sorted(gradients.items()):
            hasher.update(name.encode())
            hasher.update(data[:100])  # Sample for speed
        return hasher.hexdigest()

    def get_status(self) -> Dict[str, Any]:
        """Get coordinator status."""
        return {
            "state": self.state.value,
            "global_step": self.global_step,
            "total_rounds": self.total_rounds,
            "total_contributions": self.total_contributions,
            "total_neuro_distributed": self.total_neuro_distributed,
            "current_round": {
                "round_id": self.current_round.round_id,
                "contributions": len(self.current_round.contributions),
                "elapsed": time.time() - self.current_round.started_at,
                "model_hash": self.current_round.model_hash,
            } if self.current_round else None,
            "recent_rounds": [
                {
                    "round_id": r.round_id,
                    "contributions": len(r.contributions),
                    "avg_loss": r.avg_loss,
                    "total_batch_size": r.total_batch_size,
                }
                for r in self.round_history[-10:]
            ]
        }

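A local walk-through of the coordinator API with a toy model (purely illustrative, given the LEGACY note above). compute_local_gradients expects the model's forward to accept input_ids/labels and return a dict with a "loss" entry, so the stand-in below does exactly that:

import time
import torch
import torch.nn as nn
from neuroshard.core.training.distributed import TrainingCoordinator

class ToyLM(nn.Module):
    # Minimal stand-in: embed tokens, project to vocab, return a loss dict.
    def __init__(self, vocab=300, dim=16):
        super().__init__()
        self.emb = nn.Embedding(vocab, dim)
        self.head = nn.Linear(dim, vocab)

    def forward(self, input_ids, labels):
        logits = self.head(self.emb(input_ids))
        loss = nn.functional.cross_entropy(logits.view(-1, logits.size(-1)), labels.view(-1))
        return {"loss": loss}

model = ToyLM()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
coordinator = TrainingCoordinator(model, optimizer, node_id="example-node")

input_ids = torch.randint(0, 300, (4, 32))
labels = torch.randint(0, 300, (4, 32))
contribution = coordinator.compute_local_gradients(input_ids, labels)
print(contribution.batch_size, round(contribution.loss, 3))

coordinator.start()
time.sleep(2)  # give the background loop a moment to open a COLLECTING round
print("accepted:", coordinator.submit_contribution(contribution))
print(coordinator.get_status()["state"])
coordinator.stop()
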
class GenesisDataLoader:
    """
    Loads training data from the verified Genesis Dataset.

    Features:
    - Dynamic shard count (reads from manifest)
    - User-configurable storage limit (max_storage_mb)
    - Shard rotation (cycles through dataset over time)
    - Multi-shard support (downloads multiple shards up to storage limit)
    - ASYNC PREFETCHING: Pre-downloads next shard while training on current

    Active only for nodes holding Layer 0 (Embedding Layer).

    Data Source: CloudFront CDN (backed by S3)
    """
    # CloudFront CDN URL - single source of truth (cached, DDoS protected)
    GENESIS_CDN_URL = "https://dwquwt9gkkeil.cloudfront.net"
    # Size per shard in MB (must match populate_genesis_s3.py)
    SHARD_SIZE_MB = 10

    def __init__(
        self,
        node_id: str,
        tokenizer,
        cache_dir: str = None,  # Default to ~/.neuroshard/data_cache
        max_storage_mb: float = 100.0,  # User-configurable limit
        manifest_version: int = 1
    ):
        self.node_id = node_id
        self.tokenizer = tokenizer

        # Default cache_dir to ~/.neuroshard/data_cache for consistent storage
        if cache_dir is None:
            cache_dir = os.path.join(os.path.expanduser("~"), ".neuroshard", "data_cache")
        self.cache_dir = cache_dir
        self.max_storage_mb = max_storage_mb
        self.manifest_version = manifest_version

        # CloudFront CDN manifest URL - single source of truth
        self.manifest_url = f"{self.GENESIS_CDN_URL}/manifest.json"

        # Manifest data (cached, refreshed periodically)
        self.manifest = None
        self.total_shards = 0
        self.manifest_last_fetch = 0
        self.MANIFEST_REFRESH_INTERVAL = 600  # Refresh manifest every 10 minutes (auto-update tokenizer)

        # Shard management
        self.max_shards = max(1, int(max_storage_mb / self.SHARD_SIZE_MB))
        self.assigned_shard_ids = []  # List of shard IDs this node is responsible for
        self.loaded_shards = {}  # shard_id -> tensor data
        self.current_shard_idx = 0  # Index into assigned_shard_ids for rotation
        self.shard_rotation_count = 0  # How many times we've rotated through
        self.loading_shards = set()  # Track shards currently being downloaded
        self._shard_lock = threading.Lock()  # Lock for shard loading
        self._download_executor = ThreadPoolExecutor(max_workers=3, thread_name_prefix="shard-download")

        # ASYNC PREFETCHING: Keep next shard(s) ready in background
        self._prefetch_in_progress = set()  # Shard IDs being prefetched
        self._prefetch_ready = {}  # shard_id -> tensor data (ready to use)
        self._prefetch_ahead = 2  # Number of shards to prefetch ahead (was 1)

        # LOSS PLATEAU DETECTION: Track loss to detect when to rotate shards early
        self._loss_history = []  # Recent loss values
        self._loss_history_max = 50  # Number of loss values to track
        self._loss_plateau_threshold = 0.02  # If loss variance < this, plateau detected
        self._min_steps_per_shard = 100  # Minimum steps before considering early rotation
        self._steps_on_current_shard = 0  # Steps taken on current shard

        # Initialize Data Swarm for P2P downloading
        self.swarm = None

        self.current_dataset = None
        self.dataset_iterator = 0

        # Fetch manifest and assign initial shards
        self._refresh_manifest()
        self._assign_shards()

        # Try to load learned tokenizer from CDN (for proper vocab)
        self._load_learned_tokenizer()

        # THUNDERING HERD PREVENTION: Add random jitter before first download
        # This spreads load across the CDN when many nodes start simultaneously
        # Jitter: 0-5 seconds based on node_id hash
        import random
        jitter_seed = int(hashlib.sha256(self.node_id.encode()).hexdigest()[:8], 16)
        jitter_seconds = (jitter_seed % 5000) / 1000.0  # 0-5 seconds

        def delayed_prefetch():
            time.sleep(jitter_seconds)
            self._start_prefetch_next()

        # Start prefetching first shard with jitter (non-blocking)
        threading.Thread(target=delayed_prefetch, daemon=True).start()

        logger.info(f"GenesisDataLoader initialized: "
                    f"total_shards={self.total_shards}, "
                    f"max_storage={max_storage_mb}MB ({self.max_shards} shards), "
                    f"assigned={self.assigned_shard_ids[:5]}{'...' if len(self.assigned_shard_ids) > 5 else ''}, "
                    f"prefetch_jitter={jitter_seconds:.2f}s")

    def _load_learned_tokenizer(self):
        """
        Download and load the learned tokenizer from Genesis CDN.
        Checks if the network has learned more tokens and updates locally.
        """
        try:
            tokenizer_url = f"{self.GENESIS_CDN_URL}/tokenizer.json"
            tokenizer_cache_path = os.path.join(self.cache_dir, "tokenizer.json")

            # Always try to fetch latest from CDN
            try:
                logger.debug(f"[GENESIS] Checking for tokenizer updates from {tokenizer_url}...")
                resp = requests.get(tokenizer_url, timeout=10)

                if resp.status_code == 200:
                    remote_tokenizer_data = resp.json()
                    remote_vocab_size = remote_tokenizer_data.get("next_merge_id", 0)

                    # Always cache the downloaded tokenizer (for offline use)
                    os.makedirs(self.cache_dir, exist_ok=True)
                    with open(tokenizer_cache_path, 'w') as f:
                        f.write(resp.text)

                    # Update our tokenizer if remote has more tokens
                    if remote_vocab_size > self.tokenizer.next_merge_id:
                        logger.info(f"[GENESIS] Found improved tokenizer! ({self.tokenizer.next_merge_id} -> {remote_vocab_size} tokens)")

                        from neuroshard.core.model.tokenizer import NeuroTokenizer
                        learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id

                        logger.info(f"[GENESIS] Tokenizer updated: {self.tokenizer.next_merge_id} tokens")
                    else:
                        logger.info(f"[GENESIS] Tokenizer cached: {remote_vocab_size} tokens (current: {self.tokenizer.next_merge_id})")
                    return
            except Exception as e:
                logger.debug(f"[GENESIS] Failed to check for tokenizer updates: {e}")

            # Fallback to cached version if download failed
            if os.path.exists(tokenizer_cache_path) and self.tokenizer.next_merge_id <= 266:
                logger.info(f"[GENESIS] Loading cached tokenizer from {tokenizer_cache_path}")
                try:
                    from neuroshard.core.model.tokenizer import NeuroTokenizer
                    learned_tokenizer = NeuroTokenizer.load(tokenizer_cache_path)

                    if learned_tokenizer.next_merge_id > self.tokenizer.next_merge_id:
                        self.tokenizer.merges = learned_tokenizer.merges
                        self.tokenizer.merge_to_tokens = learned_tokenizer.merge_to_tokens
                        self.tokenizer.next_merge_id = learned_tokenizer.next_merge_id
                        logger.info(f"[GENESIS] Loaded cached tokenizer: {self.tokenizer.next_merge_id} tokens")
                except Exception as e:
                    logger.warning(f"[GENESIS] Failed to load cached tokenizer: {e}")

        except Exception as e:
            logger.warning(f"[GENESIS] Error managing tokenizer: {e}")

    def _refresh_manifest_sync(self):
        """Synchronous manifest fetch (runs in background thread)."""
        try:
            logger.info(f"[GENESIS] Fetching manifest from {self.manifest_url}...")
            resp = requests.get(self.manifest_url, timeout=15)
            if resp.status_code == 200:
                manifest_data = resp.json()
                total_shards = manifest_data.get("total_shards", 0)

                # Update state atomically
                with self._shard_lock:
                    self.manifest = manifest_data
                    self.total_shards = total_shards
                    self.manifest_last_fetch = time.time()

                logger.info(f"[GENESIS] Manifest loaded: {self.total_shards} shards available")

                # Also check if tokenizer has improved (in background)
                self._load_learned_tokenizer()
            else:
                logger.error(f"[GENESIS] Failed to fetch manifest: HTTP {resp.status_code}")
                logger.error(f"[GENESIS] Response: {resp.text[:200]}")
        except Exception as e:
            logger.error(f"[GENESIS] Failed to fetch manifest from {self.manifest_url}: {type(e).__name__}: {e}")
            import traceback
            logger.error(f"[GENESIS] Traceback: {traceback.format_exc()}")

    def _refresh_manifest(self):
        """Fetch latest manifest from S3 (non-blocking after first load)."""
        now = time.time()

        # First time initialization - must be synchronous
        if self.manifest is None:
            self._refresh_manifest_sync()
            if self.total_shards == 0:
                raise RuntimeError(f"Cannot fetch manifest from {self.manifest_url}. Check S3 bucket.")
            return

        # Subsequent refreshes - use cached if recent
        if (now - self.manifest_last_fetch) < self.MANIFEST_REFRESH_INTERVAL:
            return  # Use cached manifest

        # Refresh in background (non-blocking)
        self._download_executor.submit(self._refresh_manifest_sync)

    def _assign_shards(self):
        """
        Assign shards to this node based on:
        1. Node's deterministic hash (ensures different nodes get different shards)
        2. User's storage limit (max_shards)
        3. Rotation offset (allows cycling through entire dataset over time)
        """
        if self.total_shards == 0:
            self.assigned_shard_ids = [0]
            return

        # Base offset from node ID (deterministic)
        node_hash = int(hashlib.sha256(self.node_id.encode()).hexdigest(), 16)
        base_offset = node_hash % self.total_shards

        # Rotation offset (changes over time to cover more data)
        rotation_offset = (self.shard_rotation_count * self.max_shards) % self.total_shards

        # Assign shards starting from (base + rotation) offset
        self.assigned_shard_ids = []
        for i in range(self.max_shards):
            shard_id = (base_offset + rotation_offset + i) % self.total_shards
            self.assigned_shard_ids.append(shard_id)

        logger.info(f"Assigned {len(self.assigned_shard_ids)} shards: "
                    f"{self.assigned_shard_ids[:5]}{'...' if len(self.assigned_shard_ids) > 5 else ''}")

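The assignment above is pure modular arithmetic, so it can be reproduced outside the class. A standalone sketch with made-up numbers (500 shards, a 100 MB budget at 10 MB per shard, and an invented node ID):

import hashlib

total_shards = 500
max_shards = max(1, int(100.0 / 10))  # max_storage_mb / SHARD_SIZE_MB -> 10 shards
node_id = "example-node"
shard_rotation_count = 0

node_hash = int(hashlib.sha256(node_id.encode()).hexdigest(), 16)
base_offset = node_hash % total_shards
rotation_offset = (shard_rotation_count * max_shards) % total_shards

assigned = [(base_offset + rotation_offset + i) % total_shards
            for i in range(max_shards)]
print(assigned)  # 10 consecutive shard IDs, deterministic for this node ID
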
    def rotate_shards(self):
        """
        Rotate to next set of shards.
        Call this periodically to train on different parts of the dataset.
        """
        # Clear old loaded shards to free memory
        old_shards = list(self.loaded_shards.keys())
        self.loaded_shards.clear()
        self.current_dataset = None
        self.dataset_iterator = 0

        # Increment rotation counter
        self.shard_rotation_count += 1

        # Refresh manifest (in case new shards were added)
        self._refresh_manifest()

        # Reassign shards with new rotation offset
        self._assign_shards()

        # Clean up old shard files from disk
        self._cleanup_old_shards(old_shards)

        logger.info(f"Rotated to new shards (rotation #{self.shard_rotation_count})")

    def _cleanup_old_shards(self, old_shard_ids: list):
        """Remove old shard files from disk to stay within storage limit."""
        for shard_id in old_shard_ids:
            if shard_id not in self.assigned_shard_ids:
                shard_path = os.path.join(self.cache_dir, f"genesis_shard_{shard_id}.pt")
                try:
                    if os.path.exists(shard_path):
                        os.remove(shard_path)
                        logger.debug(f"Cleaned up old shard: {shard_path}")
                except Exception as e:
                    logger.warning(f"Failed to cleanup shard {shard_id}: {e}")

    def set_swarm(self, swarm):
        """Set the DataSwarm instance."""
        self.swarm = swarm

    def record_loss(self, loss: float):
        """
        Record a training loss for plateau detection.

        Call this from the training loop to enable adaptive shard rotation.
        When loss plateaus, the loader will rotate to fresh data.
        """
        self._loss_history.append(loss)
        if len(self._loss_history) > self._loss_history_max:
            self._loss_history.pop(0)
        self._steps_on_current_shard += 1

    def _should_rotate_early(self) -> bool:
        """
        Check if we should rotate to a new shard early due to loss plateau.

        Conditions for early rotation:
        1. Have enough loss samples (at least 20)
        2. Minimum steps on current shard (100)
        3. Loss has plateaued (low variance)
        4. Loss is low enough that we're not still actively learning
        """
        if len(self._loss_history) < 20:
            return False

        if self._steps_on_current_shard < self._min_steps_per_shard:
            return False

        # Calculate loss statistics
        recent_losses = self._loss_history[-20:]
        avg_loss = sum(recent_losses) / len(recent_losses)
        variance = sum((l - avg_loss) ** 2 for l in recent_losses) / len(recent_losses)

        # Also check if loss is decreasing (don't rotate if still improving)
        if len(self._loss_history) >= 40:
            older_avg = sum(self._loss_history[-40:-20]) / 20
            improvement = older_avg - avg_loss

            # Still improving significantly - don't rotate
            if improvement > 0.005:
                return False

        # Plateau detected: low variance AND low absolute loss
        if variance < self._loss_plateau_threshold and avg_loss < 0.05:
            logger.info(f"[GENESIS] Loss plateau detected: avg={avg_loss:.4f}, variance={variance:.6f}")
            logger.info(f"[GENESIS] Rotating to fresh data for continued learning")
            return True

        return False

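Two made-up 20-step loss windows show how the plateau test behaves; the helper below repeats just the variance/average check from _should_rotate_early (thresholds as above: variance below 0.02 and average below 0.05):

flat = [0.031, 0.029, 0.030, 0.032, 0.028] * 4  # hovering around 0.03
high = [0.40, 0.25, 0.18, 0.12, 0.09] * 4       # loss still well above the cutoff

def plateaued(recent, threshold=0.02, loss_cutoff=0.05):
    avg = sum(recent) / len(recent)
    var = sum((l - avg) ** 2 for l in recent) / len(recent)
    return var < threshold and avg < loss_cutoff

print(plateaued(flat))  # True  -> rotate to a fresh shard
print(plateaued(high))  # False -> keep training on the current shard
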
    def force_shard_rotation(self, reason: str = "manual"):
        """
        Force rotation to a new shard.

        Call this when you want to move to fresh data (e.g., loss plateau).
        """
        logger.info(f"[GENESIS] Forcing shard rotation: {reason}")
        self._loss_history.clear()
        self._steps_on_current_shard = 0

        # Move to next shard
        self.current_shard_idx += 1

        if self.current_shard_idx >= len(self.assigned_shard_ids):
            # We've gone through all assigned shards - rotate to new set
            logger.info(f"[GENESIS] Exhausted all {len(self.assigned_shard_ids)} assigned shards. Getting new set...")
            self.rotate_shards()

        # Reset dataset iterator to start fresh
        self.current_dataset = None
        self.dataset_iterator = 0

        # Start prefetching the new shard
        self._start_prefetch_next()

    def _start_prefetch_next(self):
        """Start prefetching the next shard(s) in background."""
        if not self.assigned_shard_ids:
            return

        # Prefetch current and multiple next shards for faster data access
        shards_to_prefetch = []
        for offset in range(self._prefetch_ahead + 1):  # Current + prefetch_ahead (default: 0, 1, 2)
            idx = (self.current_shard_idx + offset) % len(self.assigned_shard_ids)
            shard_id = self.assigned_shard_ids[idx]

            with self._shard_lock:
                # Skip if already loaded, prefetching, or ready
                if (shard_id in self.loaded_shards or
                        shard_id in self._prefetch_in_progress or
                        shard_id in self._prefetch_ready or
                        shard_id in self.loading_shards):
                    continue

                # Limit total prefetch in progress to avoid overwhelming the system
                if len(self._prefetch_in_progress) >= 3:
                    break

                shards_to_prefetch.append(shard_id)
                self._prefetch_in_progress.add(shard_id)

        # Start downloads in background
        for shard_id in shards_to_prefetch:
            target_url = self.get_shard_url(shard_id)
            logger.debug(f"Prefetching shard {shard_id} in background...")
            self._download_executor.submit(self._prefetch_shard_sync, shard_id, target_url)

    def _prefetch_shard_sync(self, shard_id: int, target_url: str):
        """Synchronous shard prefetch (runs in background thread)."""
        try:
            logger.info(f"[GENESIS] Downloading shard {shard_id}...")
            # Download the Shard
            shard_path = None

            if self.swarm:
                try:
                    shard_path = self.swarm.download_shard(shard_id, manifest_url=target_url)
                    logger.info(f"[GENESIS] Swarm download succeeded for shard {shard_id}")
                except Exception as e:
                    logger.warning(f"[GENESIS] Swarm prefetch failed: {e}")

            if not shard_path:
                logger.info(f"[GENESIS] Using HTTP fallback for shard {shard_id}")
                shard_path = self._http_fallback_download(shard_id, target_url)
                logger.info(f"[GENESIS] HTTP download completed for shard {shard_id}")

            # Load tensor into prefetch buffer
            tensor_data = torch.load(shard_path, weights_only=True)

            with self._shard_lock:
                # DYNAMIC MEMORY LIMIT: Based on user's max_storage_mb setting
                # Each shard is ~10MB compressed on disk, ~100-200MB uncompressed in RAM
                # Calculate max shards we can keep in memory
                shard_size_mb = 150  # Conservative estimate per shard in RAM
                max_cached_shards = max(3, int(self.max_storage_mb / shard_size_mb))

                total_loaded = len(self.loaded_shards) + len(self._prefetch_ready)
                if total_loaded >= max_cached_shards:
                    # Clear oldest loaded shard (not the current one)
                    current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)] if self.assigned_shard_ids else None
                    for old_shard_id in list(self.loaded_shards.keys()):
                        if old_shard_id != current_shard:
                            del self.loaded_shards[old_shard_id]
                            logger.debug(f"Evicted shard {old_shard_id} from cache (limit: {max_cached_shards} shards)")
                            break

                self._prefetch_ready[shard_id] = tensor_data
                self._prefetch_in_progress.discard(shard_id)

            logger.info(f"[GENESIS] Shard {shard_id} ready: {len(tensor_data):,} tokens")

        except Exception as e:
            logger.error(f"[GENESIS] Download FAILED for shard {shard_id}: {type(e).__name__}: {e}")
            import traceback
            logger.error(f"[GENESIS] Traceback: {traceback.format_exc()}")
            with self._shard_lock:
                self._prefetch_in_progress.discard(shard_id)

    def is_data_ready(self) -> bool:
        """Check if data is ready for training (non-blocking check)."""
        # Try to acquire lock with timeout to prevent blocking training loop
        acquired = self._shard_lock.acquire(timeout=0.5)
        if not acquired:
            # Lock held by download thread - assume data might be ready soon
            logger.debug("[GENESIS] Lock contention in is_data_ready - skipping check")
            return False

        try:
            # Data ready if we have current dataset OR prefetched shard is ready
            if self.current_dataset is not None and len(self.current_dataset) > 0:
                return True

            # Check if ANY assigned shard is ready (not just current)
            # This handles the case where prefetch completes before is_data_ready is called
            if self._prefetch_ready:
                # A prefetched shard is ready - we can use it
                return True

            # Also check loaded_shards
            if self.loaded_shards:
                return True

            # Check if current shard is specifically ready
            if self.assigned_shard_ids:
                shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
                if shard_id in self._prefetch_ready:
                    return True
                if shard_id in self.loaded_shards:
                    return True

            return False
        finally:
            self._shard_lock.release()

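A sketch of how a training loop might drive the loader (everything here is illustrative: the node ID is invented, and constructing NeuroTokenizer with no arguments is an assumption; in the real runner the tokenizer comes from the node). get_batch raising RuntimeError while a shard is still downloading is handled by backing off and retrying:

import time
from neuroshard.core.model.tokenizer import NeuroTokenizer
from neuroshard.core.training.distributed import GenesisDataLoader

loader = GenesisDataLoader(
    node_id="example-node",
    tokenizer=NeuroTokenizer(),  # assumed default construction
    max_storage_mb=100.0,
)

for step in range(1000):
    if not loader.is_data_ready():
        time.sleep(1.0)  # shard still downloading
        continue
    try:
        input_ids, labels = loader.get_batch(batch_size=4, seq_len=512)
    except RuntimeError:
        time.sleep(1.0)  # lost a race with a shard rotation; retry
        continue
    # ... forward/backward pass on (input_ids, labels) goes here ...
    loss_value = 1.0  # placeholder for the real loss.item()
    loader.record_loss(loss_value)  # feeds the plateau detector
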
1143
|
+
def get_shard_url(self, shard_id: int) -> str:
|
|
1144
|
+
"""Get download URL for a specific shard (always use CDN)."""
|
|
1145
|
+
# Always use CDN URL regardless of what manifest says
|
|
1146
|
+
# This ensures we go through CloudFront for caching/security
|
|
1147
|
+
return f"{self.GENESIS_CDN_URL}/shard_{shard_id}.pt"
|
|
1148
|
+
|
|
1149
|
+
def _load_shard_sync(self, shard_id: int, target_url: str):
|
|
1150
|
+
"""Synchronous shard loading (runs in background thread)."""
|
|
1151
|
+
# Download the Shard (Swarm or HTTP)
|
|
1152
|
+
shard_path = None
|
|
1153
|
+
|
|
1154
|
+
if self.swarm:
|
|
1155
|
+
try:
|
|
1156
|
+
shard_path = self.swarm.download_shard(shard_id, manifest_url=target_url)
|
|
1157
|
+
except Exception as e:
|
|
1158
|
+
logger.error(f"Swarm download failed: {e}")
|
|
1159
|
+
|
|
1160
|
+
if not shard_path:
|
|
1161
|
+
shard_path = self._http_fallback_download(shard_id, target_url)
|
|
1162
|
+
|
|
1163
|
+
# Load tensor
|
|
1164
|
+
try:
|
|
1165
|
+
tensor_data = torch.load(shard_path, weights_only=True)
|
|
1166
|
+
with self._shard_lock:
|
|
1167
|
+
self.loaded_shards[shard_id] = tensor_data
|
|
1168
|
+
self.current_dataset = tensor_data
|
|
1169
|
+
self.dataset_iterator = 0
|
|
1170
|
+
self.loading_shards.discard(shard_id)
|
|
1171
|
+
logger.info(f"Loaded Shard {shard_id}: {len(tensor_data)} tokens")
|
|
1172
|
+
except Exception as e:
|
|
1173
|
+
logger.error(f"Failed to load shard {shard_path}: {e}")
|
|
1174
|
+
with self._shard_lock:
|
|
1175
|
+
self.loading_shards.discard(shard_id)
|
|
1176
|
+
# Create dummy data if all else fails (use valid byte tokens 10-265)
|
|
1177
|
+
self.current_dataset = torch.randint(10, 266, (10000,), dtype=torch.long)
|
|
1178
|
+
|
|
1179
|
+
def ensure_shard_loaded(self, shard_id: int = None):
|
|
1180
|
+
"""
|
|
1181
|
+
Download and load a shard if not present.
|
|
1182
|
+
Opportunistically switches to ANY ready shard if the target isn't ready.
|
|
1183
|
+
"""
|
|
1184
|
+
target_shard_id = shard_id
|
|
1185
|
+
|
|
1186
|
+
if target_shard_id is None:
|
|
1187
|
+
# Default: try current shard in rotation
|
|
1188
|
+
if not self.assigned_shard_ids:
|
|
1189
|
+
return
|
|
1190
|
+
target_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
|
|
1191
|
+
|
|
1192
|
+
with self._shard_lock:
|
|
1193
|
+
# 1. Check if target is ready (Fastest)
|
|
1194
|
+
if target_shard_id in self.loaded_shards:
|
|
1195
|
+
self.current_dataset = self.loaded_shards[target_shard_id]
|
|
1196
|
+
return
|
|
1197
|
+
|
|
1198
|
+
# 2. Check if target is in prefetch buffer
|
|
1199
|
+
if target_shard_id in self._prefetch_ready:
|
|
1200
|
+
self.current_dataset = self._prefetch_ready.pop(target_shard_id)
|
|
1201
|
+
self.loaded_shards[target_shard_id] = self.current_dataset
|
|
1202
|
+
self.dataset_iterator = 0
|
|
1203
|
+
logger.info(f"Using prefetched shard {target_shard_id}: {len(self.current_dataset)} tokens")
|
|
1204
|
+
self._start_prefetch_next_unlocked()
|
|
1205
|
+
return
|
|
1206
|
+
|
|
1207
|
+
# 3. OPPORTUNISTIC: If target isn't ready, check if ANY assigned shard is ready in prefetch
|
|
1208
|
+
# This prevents blocking on shard A when shard B is already downloaded
|
|
1209
|
+
if shard_id is None: # Only if caller didn't request specific shard
|
|
1210
|
+
for ready_id in list(self._prefetch_ready.keys()):
|
|
1211
|
+
if ready_id in self.assigned_shard_ids:
|
|
1212
|
+
# Switch to this ready shard!
|
|
1213
|
+
logger.info(f"Opportunistically switching to ready shard {ready_id} (was waiting for {target_shard_id})")
|
|
1214
|
+
|
|
1215
|
+
# Update index to match
|
|
1216
|
+
try:
|
|
1217
|
+
new_idx = self.assigned_shard_ids.index(ready_id)
|
|
1218
|
+
self.current_shard_idx = new_idx
|
|
1219
|
+
except ValueError:
|
|
1220
|
+
pass
|
|
1221
|
+
|
|
1222
|
+
self.current_dataset = self._prefetch_ready.pop(ready_id)
|
|
1223
|
+
self.loaded_shards[ready_id] = self.current_dataset
|
|
1224
|
+
self.dataset_iterator = 0
|
|
1225
|
+
self._start_prefetch_next_unlocked()
|
|
1226
|
+
return
|
|
1227
|
+
|
|
1228
|
+
# 4. If still nothing, trigger download for target
|
|
1229
|
+
if target_shard_id in self.loading_shards or target_shard_id in self._prefetch_in_progress:
|
|
1230
|
+
logger.debug(f"Shard {target_shard_id} is already being downloaded, waiting...")
|
|
1231
|
+
return # Don't block
|
|
1232
|
+
|
|
1233
|
+
# Mark as loading and start download in background
|
|
1234
|
+
self.loading_shards.add(target_shard_id)
|
|
1235
|
+
|
|
1236
|
+
target_url = self.get_shard_url(target_shard_id)
|
|
1237
|
+
logger.info(f"Loading Shard {target_shard_id} from {target_url}")
|
|
1238
|
+
|
|
1239
|
+
# Submit to thread pool (non-blocking)
|
|
1240
|
+
self._download_executor.submit(self._load_shard_sync, target_shard_id, target_url)
|
|
1241
|
+
|
|
1242
|
+
    def _start_prefetch_next_unlocked(self):
        """Start prefetching next shard (call only when holding _shard_lock)."""
        # Schedule prefetch in background (don't hold lock during download)
        self._download_executor.submit(self._start_prefetch_next)

    def _http_fallback_download(self, shard_id: int, target_url: str = None) -> str:
        """Download shard from CloudFront CDN."""
        os.makedirs(self.cache_dir, exist_ok=True)
        shard_path = os.path.join(self.cache_dir, f"genesis_shard_{shard_id}.pt")

        if os.path.exists(shard_path):
            return shard_path

        # Use target URL from manifest, or construct CDN URL
        url = target_url or f"{self.GENESIS_CDN_URL}/shard_{shard_id}.pt"

        try:
            with requests.get(url, stream=True, timeout=60) as r:
                r.raise_for_status()
                with open(shard_path, 'wb') as f:
                    for chunk in r.iter_content(chunk_size=8192):
                        f.write(chunk)
            logger.info(f"Downloaded shard {shard_id}: {os.path.getsize(shard_path)/1e6:.1f}MB")
            return shard_path
        except Exception as e:
            logger.error(f"Failed to download shard {shard_id} from {url}: {e}")
            raise RuntimeError(f"Failed to download shard {shard_id}: {e}")

    def get_batch(self, batch_size: int = 4, seq_len: int = 512) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch from the current shard.

        NON-BLOCKING VERSION: Returns quickly if data not ready.
        Uses prefetch buffer for instant shard switches.

        Automatically rotates to next shard when current one is exhausted.
        Returns (input_ids, labels).

        Raises RuntimeError if data not ready (caller should retry later).
        """
        # Try to load from prefetch buffer first
        self.ensure_shard_loaded()

        # NON-BLOCKING: Check if data is actually ready
        # Don't wait/block - let the caller handle the retry
        if self.current_dataset is None:
            # Check if anything is in progress
            with self._shard_lock:
                loading_any = bool(self.loading_shards or self._prefetch_in_progress)
                prefetch_ready = bool(self._prefetch_ready)

            if prefetch_ready:
                # There's a prefetched shard - try to use it
                self.ensure_shard_loaded()
            elif not loading_any:
                # Nothing loading - kick off a new load
                self._start_prefetch_next()

            # Return early - data not ready yet
            raise RuntimeError("Data not ready - shard still loading")

        data_len = len(self.current_dataset)
        req_len = (batch_size * seq_len) + 1

        # Check for early rotation due to loss plateau
        if self._should_rotate_early():
            current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            logger.info(f"[GENESIS] Early rotation from shard {current_shard} due to loss plateau")
            self.force_shard_rotation("loss_plateau")
            # Ensure new shard is loaded
            self.ensure_shard_loaded()
            if self.current_dataset is None:
                raise RuntimeError("Data not ready - loading fresh shard after plateau")
            data_len = len(self.current_dataset)

        # Check if we've exhausted current shard
        if self.dataset_iterator + req_len > data_len:
            # Log completion of current shard
            completed_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            steps_done = data_len // req_len
            logger.info(f"✓ Completed shard {completed_shard} ({steps_done} steps, {data_len:,} tokens)")

            # Reset loss tracking for new shard
            self._loss_history.clear()
            self._steps_on_current_shard = 0

            # Move to next shard in our assigned list
            self.current_shard_idx += 1

            if self.current_shard_idx >= len(self.assigned_shard_ids):
                # We've gone through all assigned shards - rotate to new set
                logger.info(f"Exhausted all {len(self.assigned_shard_ids)} assigned shards. Rotating to new set...")
                self.rotate_shards()

            # Try to use prefetched shard (FAST PATH)
            next_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]

            with self._shard_lock:
                if next_shard_id in self._prefetch_ready:
                    # Instant switch to prefetched shard
                    self.current_dataset = self._prefetch_ready.pop(next_shard_id)
                    self.loaded_shards[next_shard_id] = self.current_dataset
                    logger.info(f"Switched to prefetched shard {next_shard_id}: {len(self.current_dataset)} tokens")
                elif next_shard_id in self.loaded_shards:
                    self.current_dataset = self.loaded_shards[next_shard_id]
                else:
                    # Need to wait for next shard - trigger load
                    self.ensure_shard_loaded(next_shard_id)
                    raise RuntimeError("Data not ready - loading next shard")

            # Start prefetching the shard after next
            self._start_prefetch_next()

            self.dataset_iterator = 0
            data_len = len(self.current_dataset)

        start_idx = self.dataset_iterator
        end_idx = start_idx + req_len

        chunk = self.current_dataset[start_idx:end_idx]
        self.dataset_iterator += req_len

        # Log shard progress periodically (every 100 steps within shard)
        steps_in_shard = self.dataset_iterator // req_len
        total_steps_in_shard = data_len // req_len
        if steps_in_shard % 100 == 0:
            progress_pct = (self.dataset_iterator / data_len) * 100
            current_shard = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]
            logger.info(f"Shard {current_shard} progress: {progress_pct:.1f}% "
                        f"({steps_in_shard}/{total_steps_in_shard} steps)")

        # Prepare batch
        exact_len = batch_size * seq_len

        inputs = chunk[:exact_len].view(batch_size, seq_len)
        labels = chunk[1:exact_len+1].view(batch_size, seq_len)

        return inputs, labels
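
get_batch deliberately raises RuntimeError instead of blocking while a shard downloads in the background, so the training loop is expected to catch the error and retry. A hedged sketch of that caller-side pattern; `loader` is a placeholder for an already-constructed shard loader, not an API defined in this file:

import time

def next_batch_with_retry(loader, batch_size=4, seq_len=512, max_wait_s=300, poll_s=2.0):
    """Poll the non-blocking loader until a batch is ready or the deadline passes."""
    deadline = time.monotonic() + max_wait_s
    while True:
        try:
            return loader.get_batch(batch_size=batch_size, seq_len=seq_len)
        except RuntimeError:
            # "Data not ready" - the shard is still downloading in the background.
            if time.monotonic() > deadline:
                raise
            time.sleep(poll_s)
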
    def get_stats(self) -> dict:
        """Get loader statistics."""
        # Calculate progress within current shard
        shard_progress = 0.0
        steps_in_shard = 0
        total_steps_in_shard = 0
        current_shard_id = None

        if self.current_dataset is not None and len(self.current_dataset) > 0:
            data_len = len(self.current_dataset)
            req_len = 1025  # Approximate batch_size * seq_len + 1 (assumes ~2 * 512); actual value depends on the caller
            shard_progress = (self.dataset_iterator / data_len) * 100
            steps_in_shard = self.dataset_iterator // req_len
            total_steps_in_shard = data_len // req_len
            if self.assigned_shard_ids:
                current_shard_id = self.assigned_shard_ids[self.current_shard_idx % len(self.assigned_shard_ids)]

        # Compute loss plateau stats
        loss_avg = 0.0
        loss_variance = 0.0
        if self._loss_history:
            loss_avg = sum(self._loss_history) / len(self._loss_history)
            if len(self._loss_history) >= 2:
                loss_variance = sum((l - loss_avg) ** 2 for l in self._loss_history) / len(self._loss_history)

        return {
            "total_shards_available": self.total_shards,
            "max_shards_configured": self.max_shards,
            "max_storage_mb": self.max_storage_mb,
            "assigned_shards": len(self.assigned_shard_ids),
            "loaded_shards": len(self.loaded_shards),
            "prefetch_in_progress": len(self._prefetch_in_progress),
            "prefetch_ready": len(self._prefetch_ready),
            "current_shard_idx": self.current_shard_idx,
            "current_shard_id": current_shard_id,
            "shard_progress_pct": round(shard_progress, 1),
            "steps_in_shard": steps_in_shard,
            "total_steps_in_shard": total_steps_in_shard,
            "rotation_count": self.shard_rotation_count,
            "storage_used_mb": len(self.loaded_shards) * self.SHARD_SIZE_MB,
            # Loss plateau detection stats
            "steps_on_current_shard": self._steps_on_current_shard,
            "loss_history_size": len(self._loss_history),
            "loss_avg": round(loss_avg, 6),
            "loss_variance": round(loss_variance, 8),
            "plateau_threshold": self._loss_plateau_threshold,
        }
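
The dict returned by get_stats is a monitoring snapshot rather than anything training depends on. An illustrative read-out only, assuming `loader` is an existing instance of this class and reusing the module's `logger`:

stats = loader.get_stats()
logger.info(
    f"shard {stats['current_shard_id']} at {stats['shard_progress_pct']}% "
    f"({stats['loaded_shards']} loaded, {stats['prefetch_ready']} prefetched, "
    f"{stats['storage_used_mb']}MB on disk)"
)
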
class DataValidator:
    """
    Validates training data quality before it enters the buffer.

    Prevents garbage/spam from polluting the local training set.
    """
    def __init__(self):
        pass

    def validate_text(self, text: str) -> Tuple[bool, str]:
        """
        Validate text quality.
        Returns (is_valid, reason).
        """
        if not text or not text.strip():
            return False, "Empty text"

        if len(text) < 20:
            return False, "Text too short (<20 chars)"

        # Entropy check (compression ratio)
        # Highly repetitive text compresses too well (rejected above 6.0)
        # Random gibberish compresses poorly (ratio close to 1.0)
        import zlib
        compressed = zlib.compress(text.encode())
        ratio = len(text) / len(compressed)

        if ratio > 6.0:
            return False, f"High compression ratio ({ratio:.1f}) - likely repetitive spam"

        if ratio < 1.1 and len(text) > 200:
            return False, f"Low compression ratio ({ratio:.1f}) - likely random gibberish"

        # Basic character distribution check:
        # reject text that is mostly special characters
        alnum_count = sum(c.isalnum() for c in text)
        if alnum_count / len(text) < 0.5:
            return False, "Too many special characters"

        return True, "OK"
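
A quick way to see the compression heuristic at work: repetitive spam compresses far better than ordinary prose and trips the 6.0 rejection threshold, while normal text passes. Illustrative usage only; the measured ratio embedded in the reason string varies slightly with the zlib build:

validator = DataValidator()

ok, reason = validator.validate_text(
    "Distributed training shards are rotated once the local loss plateaus, "
    "so each node keeps seeing fresh data without central coordination."
)
print(ok, reason)   # True OK

ok, reason = validator.validate_text("buy now " * 200)
print(ok, reason)   # False, high compression ratio (repetitive spam)
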
class FederatedDataManager:
    """
    Manages federated dataset for distributed training.

    Nodes can contribute:
    1. Text data (tokenized)
    2. Curated datasets
    3. Synthetic data from other models

    Privacy features:
    - Differential privacy (noise injection)
    - Data hashing (no raw text stored)
    - Local processing only
    """

    def __init__(self, tokenizer, max_seq_len: int = 2048):
        self.tokenizer = tokenizer
        self.max_seq_len = max_seq_len

        # Validator
        self.validator = DataValidator()

        # Local data buffer
        self.data_buffer: List[torch.Tensor] = []
        self.max_buffer_size = 10000

        # Stats
        self.total_samples_contributed = 0
        self.total_tokens_contributed = 0
        self.rejected_samples = 0

    def add_text(self, text: str, apply_dp: bool = True, epsilon: float = 1.0):
        """
        Add text to the local training buffer.

        Args:
            text: Raw text to add
            apply_dp: Apply differential privacy
            epsilon: DP epsilon (lower = more private)
        """
        # Validate first
        is_valid, reason = self.validator.validate_text(text)
        if not is_valid:
            logger.warning(f"Rejected training data: {reason}")
            self.rejected_samples += 1
            return

        # Tokenize
        tokens = self.tokenizer.encode(text)

        if len(tokens) == 0:
            return

        # Chunk into sequences with overlap
        # Use smaller chunk size for flexibility
        chunk_size = min(self.max_seq_len, 512)  # Use 512 for training efficiency
        stride = chunk_size // 2  # 50% overlap

        chunks_added = 0
        for i in range(0, max(1, len(tokens) - chunk_size + 1), stride):
            chunk = tokens[i:i + chunk_size]

            # Pad if needed
            if len(chunk) < chunk_size:
                chunk = chunk + [self.tokenizer.pad_token_id] * (chunk_size - len(chunk))

            tensor = torch.tensor(chunk, dtype=torch.long)

            # Apply differential privacy (token-level noise)
            if apply_dp:
                tensor = self._apply_dp(tensor, epsilon)

            self.data_buffer.append(tensor)
            self.total_samples_contributed += 1
            self.total_tokens_contributed += len(chunk)
            chunks_added += 1

        # Also handle short texts (< chunk_size)
        if len(tokens) < chunk_size and chunks_added == 0:
            chunk = tokens + [self.tokenizer.pad_token_id] * (chunk_size - len(tokens))
            tensor = torch.tensor(chunk, dtype=torch.long)
            if apply_dp:
                tensor = self._apply_dp(tensor, epsilon)
            self.data_buffer.append(tensor)
            self.total_samples_contributed += 1
            self.total_tokens_contributed += len(tokens)

        # Trim buffer if too large
        if len(self.data_buffer) > self.max_buffer_size:
            self.data_buffer = self.data_buffer[-self.max_buffer_size:]
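
add_text slides a 512-token window with a 256-token stride, so neighbouring chunks overlap by half. A small worked example of the resulting offsets (standalone; it simply re-evaluates the same range() expression used above):

chunk_size = 512
stride = chunk_size // 2          # 256
num_tokens = 1300                 # e.g. a medium-length contributed document

starts = list(range(0, max(1, num_tokens - chunk_size + 1), stride))
print(starts)                     # [0, 256, 512, 768]
# Chunks cover tokens 0-511, 256-767, 512-1023 and 768-1279; the last 20 tokens
# (1280-1299) are dropped because no further window start fits.
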
    def _apply_dp(self, tokens: torch.Tensor, epsilon: float) -> torch.Tensor:
        """Apply differential privacy to tokens."""
        # Simple DP: randomly replace some tokens.
        # Each token is replaced with probability min(1, 1/epsilon), so epsilon <= 1 perturbs
        # (nearly) every token. More sophisticated methods would use the exponential mechanism.
        noise_mask = torch.rand(tokens.shape) < (1.0 / epsilon)
        # Use current_vocab_size (not max vocab_size) to only sample valid tokens
        valid_vocab_size = getattr(self.tokenizer, 'current_vocab_size', 266)
        random_tokens = torch.randint(0, valid_vocab_size, tokens.shape)
        return torch.where(noise_mask, random_tokens, tokens)
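
Since each token is replaced with probability min(1, 1/epsilon), the epsilon argument maps directly to how much of a contribution survives the noise. A rough standalone check of that relationship, independent of the class above:

import torch

for epsilon in (0.5, 1.0, 2.0, 4.0):
    tokens = torch.randint(10, 266, (10_000,))
    mask = torch.rand(tokens.shape) < (1.0 / epsilon)
    print(f"epsilon={epsilon}: ~{mask.float().mean().item():.0%} of tokens replaced")
# epsilon <= 1 replaces essentially everything; epsilon=2 about half; epsilon=4 about a quarter.
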
    def get_batch(self, batch_size: int) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Get a batch for training.

        Returns:
            (input_ids, labels) - labels are shifted input_ids
        """
        if len(self.data_buffer) < batch_size:
            raise ValueError(f"Not enough data: have {len(self.data_buffer)}, need {batch_size}")

        # Random sample
        import random
        indices = random.sample(range(len(self.data_buffer)), batch_size)
        batch = torch.stack([self.data_buffer[i] for i in indices])

        # For causal LM, labels = inputs shifted by 1
        input_ids = batch[:, :-1]
        labels = batch[:, 1:]

        return input_ids, labels

    def get_stats(self) -> Dict[str, Any]:
        """Get data contribution stats."""
        return {
            "buffer_size": len(self.data_buffer),
            "total_samples": self.total_samples_contributed,
            "total_tokens": self.total_tokens_contributed,
            "rejected_samples": self.rejected_samples,
        }
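
Putting the pieces together, FederatedDataManager only needs a tokenizer exposing encode(), pad_token_id and, optionally, current_vocab_size. A hedged end-to-end sketch using a throwaway byte-level tokenizer stub; the stub and the sample text are illustrative only, and a real node would plug in the project's actual tokenizer:

class ByteTokenizerStub:
    """Minimal stand-in: byte values shifted by 10 to stay inside the valid 10-265 token range."""
    pad_token_id = 0
    current_vocab_size = 266

    def encode(self, text: str):
        return [b + 10 for b in text.encode("utf-8")]

manager = FederatedDataManager(tokenizer=ByteTokenizerStub(), max_seq_len=2048)
manager.add_text(
    "Federated contributors tokenize locally, apply token-level noise, and only ever "
    "share gradients - the raw text itself never leaves the machine. " * 2,
    apply_dp=True,
    epsilon=4.0,
)

input_ids, labels = manager.get_batch(batch_size=1)
print(input_ids.shape, labels.shape)   # torch.Size([1, 511]) torch.Size([1, 511])
print(manager.get_stats())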