nexaroa-0.0.111-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- neuroshard/__init__.py +93 -0
- neuroshard/__main__.py +4 -0
- neuroshard/cli.py +466 -0
- neuroshard/core/__init__.py +92 -0
- neuroshard/core/consensus/verifier.py +252 -0
- neuroshard/core/crypto/__init__.py +20 -0
- neuroshard/core/crypto/ecdsa.py +392 -0
- neuroshard/core/economics/__init__.py +52 -0
- neuroshard/core/economics/constants.py +387 -0
- neuroshard/core/economics/ledger.py +2111 -0
- neuroshard/core/economics/market.py +975 -0
- neuroshard/core/economics/wallet.py +168 -0
- neuroshard/core/governance/__init__.py +74 -0
- neuroshard/core/governance/proposal.py +561 -0
- neuroshard/core/governance/registry.py +545 -0
- neuroshard/core/governance/versioning.py +332 -0
- neuroshard/core/governance/voting.py +453 -0
- neuroshard/core/model/__init__.py +30 -0
- neuroshard/core/model/dynamic.py +4186 -0
- neuroshard/core/model/llm.py +905 -0
- neuroshard/core/model/registry.py +164 -0
- neuroshard/core/model/scaler.py +387 -0
- neuroshard/core/model/tokenizer.py +568 -0
- neuroshard/core/network/__init__.py +56 -0
- neuroshard/core/network/connection_pool.py +72 -0
- neuroshard/core/network/dht.py +130 -0
- neuroshard/core/network/dht_plan.py +55 -0
- neuroshard/core/network/dht_proof_store.py +516 -0
- neuroshard/core/network/dht_protocol.py +261 -0
- neuroshard/core/network/dht_service.py +506 -0
- neuroshard/core/network/encrypted_channel.py +141 -0
- neuroshard/core/network/nat.py +201 -0
- neuroshard/core/network/nat_traversal.py +695 -0
- neuroshard/core/network/p2p.py +929 -0
- neuroshard/core/network/p2p_data.py +150 -0
- neuroshard/core/swarm/__init__.py +106 -0
- neuroshard/core/swarm/aggregation.py +729 -0
- neuroshard/core/swarm/buffers.py +643 -0
- neuroshard/core/swarm/checkpoint.py +709 -0
- neuroshard/core/swarm/compute.py +624 -0
- neuroshard/core/swarm/diloco.py +844 -0
- neuroshard/core/swarm/factory.py +1288 -0
- neuroshard/core/swarm/heartbeat.py +669 -0
- neuroshard/core/swarm/logger.py +487 -0
- neuroshard/core/swarm/router.py +658 -0
- neuroshard/core/swarm/service.py +640 -0
- neuroshard/core/training/__init__.py +29 -0
- neuroshard/core/training/checkpoint.py +600 -0
- neuroshard/core/training/distributed.py +1602 -0
- neuroshard/core/training/global_tracker.py +617 -0
- neuroshard/core/training/production.py +276 -0
- neuroshard/governance_cli.py +729 -0
- neuroshard/grpc_server.py +895 -0
- neuroshard/runner.py +3223 -0
- neuroshard/sdk/__init__.py +92 -0
- neuroshard/sdk/client.py +990 -0
- neuroshard/sdk/errors.py +101 -0
- neuroshard/sdk/types.py +282 -0
- neuroshard/tracker/__init__.py +0 -0
- neuroshard/tracker/server.py +864 -0
- neuroshard/ui/__init__.py +0 -0
- neuroshard/ui/app.py +102 -0
- neuroshard/ui/templates/index.html +1052 -0
- neuroshard/utils/__init__.py +0 -0
- neuroshard/utils/autostart.py +81 -0
- neuroshard/utils/hardware.py +121 -0
- neuroshard/utils/serialization.py +90 -0
- neuroshard/version.py +1 -0
- nexaroa-0.0.111.dist-info/METADATA +283 -0
- nexaroa-0.0.111.dist-info/RECORD +78 -0
- nexaroa-0.0.111.dist-info/WHEEL +5 -0
- nexaroa-0.0.111.dist-info/entry_points.txt +4 -0
- nexaroa-0.0.111.dist-info/licenses/LICENSE +190 -0
- nexaroa-0.0.111.dist-info/top_level.txt +2 -0
- protos/__init__.py +0 -0
- protos/neuroshard.proto +651 -0
- protos/neuroshard_pb2.py +160 -0
- protos/neuroshard_pb2_grpc.py +1298 -0
neuroshard/core/training/checkpoint.py

@@ -0,0 +1,600 @@
"""
Checkpoint Sharding System for NeuroShard

This module enables distributing model checkpoints across multiple nodes,
allowing the network to scale beyond what any single node can hold.

Architecture:
- Full model is split into N shards (typically by layer groups)
- Each shard is replicated across M nodes (default M=3 for redundancy)
- Shards are announced via DHT for discovery
- Training updates are coordinated across shard holders

Sharding Strategy:
- Phase 1-2 (Bootstrap/Early): Full model per node (no sharding needed)
- Phase 3 (Growth): Optional sharding for 7B model
- Phase 4 (Mature): Mandatory sharding for 70B+ model

Example for 70B model (80 layers):
- Shard 0: Embedding + Layers 0-19 (~70GB / 4 = ~17.5GB)
- Shard 1: Layers 20-39 (~17.5GB)
- Shard 2: Layers 40-59 (~17.5GB)
- Shard 3: Layers 60-79 + LM Head (~17.5GB)
"""

import hashlib
import json
import time
import threading
import logging
from typing import Dict, List, Optional, Tuple, Set, Any
from dataclasses import dataclass, field
from enum import Enum
from pathlib import Path

import torch

logger = logging.getLogger(__name__)


class ShardingStrategy(Enum):
    """How to split the model."""
    NONE = "none"  # Full model per node (Phase 1-2)
    LAYER_GROUPS = "layer_groups"  # Split by layer ranges (Phase 3-4)
    TENSOR_PARALLEL = "tensor_parallel"  # Split within layers (future)


@dataclass
class ShardConfig:
    """Configuration for a single shard."""
    shard_id: int
    total_shards: int

    # Layer range this shard covers
    start_layer: int
    end_layer: int  # Exclusive

    # Special flags
    has_embedding: bool = False
    has_lm_head: bool = False

    # Size estimates
    estimated_size_mb: float = 0.0

    def __post_init__(self):
        """Validate shard config."""
        assert 0 <= self.shard_id < self.total_shards
        assert self.start_layer < self.end_layer


@dataclass
class ShardState:
    """State of a shard on this node."""
    config: ShardConfig

    # Version tracking
    version: int = 0
    model_hash: str = ""

    # Weights (loaded or None)
    weights: Optional[Dict[str, torch.Tensor]] = None

    # Status
    is_loaded: bool = False
    last_updated: float = field(default_factory=time.time)

    # Replication
    replicas: Set[str] = field(default_factory=set)  # Node URLs holding copies


@dataclass
class ShardAnnouncement:
    """Announcement of shard availability."""
    node_id: str
    node_url: str
    grpc_addr: str

    shard_id: int
    total_shards: int
    version: int
    model_hash: str

    # Capacity info
    available_memory_mb: float
    current_load: float  # 0-1

    timestamp: float = field(default_factory=time.time)

    def to_dict(self) -> Dict:
        return {
            "node_id": self.node_id,
            "node_url": self.node_url,
            "grpc_addr": self.grpc_addr,
            "shard_id": self.shard_id,
            "total_shards": self.total_shards,
            "version": self.version,
            "model_hash": self.model_hash,
            "available_memory_mb": self.available_memory_mb,
            "current_load": self.current_load,
            "timestamp": self.timestamp,
        }

    @classmethod
    def from_dict(cls, data: Dict) -> 'ShardAnnouncement':
        return cls(**data)


class ShardRegistry:
    """
    Registry for tracking which nodes hold which shards.

    Uses DHT for decentralized discovery:
    - Key: "shard_{model_id}_{shard_id}"
    - Value: JSON list of ShardAnnouncements
    """

    def __init__(self, dht_protocol=None, model_id: str = "neuro_llm"):
        self.dht = dht_protocol
        self.model_id = model_id

        # Local cache
        self.shard_holders: Dict[int, List[ShardAnnouncement]] = {}
        self.lock = threading.Lock()

        # Refresh interval
        self.cache_ttl = 60  # seconds
        self.last_refresh: Dict[int, float] = {}

    def announce_shard(self, announcement: ShardAnnouncement) -> bool:
        """Announce that we hold a shard."""
        if not self.dht:
            logger.warning("No DHT available for shard announcement")
            return False

        try:
            key = f"shard_{self.model_id}_{announcement.shard_id}"

            # Get existing announcements
            existing = self._get_shard_holders_from_dht(announcement.shard_id)

            # Add/update our announcement
            updated = False
            for i, ann in enumerate(existing):
                if ann.node_id == announcement.node_id:
                    existing[i] = announcement
                    updated = True
                    break

            if not updated:
                existing.append(announcement)

            # Store back to DHT
            value = json.dumps([a.to_dict() for a in existing])
            self.dht.store(key, value)

            # Update local cache
            with self.lock:
                self.shard_holders[announcement.shard_id] = existing

            logger.info(f"Announced shard {announcement.shard_id} "
                        f"(version={announcement.version}, holders={len(existing)})")
            return True

        except Exception as e:
            logger.error(f"Failed to announce shard: {e}")
            return False

    def get_shard_holders(self, shard_id: int, refresh: bool = False) -> List[ShardAnnouncement]:
        """Get nodes holding a specific shard."""
        with self.lock:
            # Check cache
            if not refresh and shard_id in self.shard_holders:
                if time.time() - self.last_refresh.get(shard_id, 0) < self.cache_ttl:
                    return self.shard_holders[shard_id]

        # Refresh from DHT
        holders = self._get_shard_holders_from_dht(shard_id)

        with self.lock:
            self.shard_holders[shard_id] = holders
            self.last_refresh[shard_id] = time.time()

        return holders

    def _get_shard_holders_from_dht(self, shard_id: int) -> List[ShardAnnouncement]:
        """Query DHT for shard holders."""
        if not self.dht:
            return []

        try:
            key = f"shard_{self.model_id}_{shard_id}"
            value = self.dht.lookup_value(key)

            if not value:
                return []

            data = json.loads(value)
            return [ShardAnnouncement.from_dict(d) for d in data]

        except Exception as e:
            logger.debug(f"DHT lookup failed for shard {shard_id}: {e}")
            return []

    def find_best_holder(self, shard_id: int) -> Optional[ShardAnnouncement]:
        """Find the best node to request a shard from."""
        holders = self.get_shard_holders(shard_id)

        if not holders:
            return None

        # Sort by: highest version, lowest load, most recent
        holders.sort(key=lambda h: (-h.version, h.current_load, -h.timestamp))

        return holders[0]

    def get_all_shards_status(self, total_shards: int) -> Dict[int, Dict]:
        """Get status of all shards."""
        status = {}

        for shard_id in range(total_shards):
            holders = self.get_shard_holders(shard_id)

            status[shard_id] = {
                "shard_id": shard_id,
                "holder_count": len(holders),
                "is_available": len(holders) > 0,
                "is_redundant": len(holders) >= 3,
                "max_version": max((h.version for h in holders), default=0),
                "holders": [h.node_url for h in holders],
            }

        return status


class CheckpointShardManager:
    """
    Manages sharded checkpoints for a NeuroLLM model.

    Responsibilities:
    1. Determine sharding strategy based on model size and node capacity
    2. Load/save individual shards
    3. Coordinate with ShardRegistry for discovery
    4. Handle shard replication
    """

    # Minimum replicas per shard
    MIN_REPLICAS = 3

    def __init__(
        self,
        model,  # NeuroLLMForCausalLM
        node_id: str,
        node_url: str,
        registry: Optional[ShardRegistry] = None,
        checkpoint_dir: Optional[Path] = None
    ):
        self.model = model
        self.node_id = node_id
        self.node_url = node_url
        self.registry = registry
        self.checkpoint_dir = checkpoint_dir or Path.home() / ".neuroshard" / "shards"
        self.checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Current shard state
        self.my_shards: Dict[int, ShardState] = {}
        self.shard_configs: List[ShardConfig] = []

        # Threading
        self.lock = threading.Lock()

    def compute_sharding_strategy(
        self,
        total_params: int,
        available_memory_mb: float,
        num_layers: int
    ) -> Tuple[ShardingStrategy, List[ShardConfig]]:
        """
        Determine optimal sharding strategy.

        Args:
            total_params: Total model parameters
            available_memory_mb: Available memory on this node
            num_layers: Number of transformer layers

        Returns:
            (strategy, list of shard configs)
        """
        # Estimate model size (4 bytes per param for float32)
        model_size_mb = (total_params * 4) / (1024 * 1024)

        # Add overhead for optimizer state, activations (2x)
        required_memory_mb = model_size_mb * 2

        logger.info(f"Model size: {model_size_mb:.1f}MB, "
                    f"Required: {required_memory_mb:.1f}MB, "
                    f"Available: {available_memory_mb:.1f}MB")

        # Strategy decision
        if required_memory_mb <= available_memory_mb:
            # Can hold full model
            return ShardingStrategy.NONE, [
                ShardConfig(
                    shard_id=0,
                    total_shards=1,
                    start_layer=0,
                    end_layer=num_layers,
                    has_embedding=True,
                    has_lm_head=True,
                    estimated_size_mb=model_size_mb
                )
            ]

        # Need to shard - calculate number of shards needed
        shard_memory_target = available_memory_mb * 0.8  # Leave 20% headroom
        num_shards = max(2, int(required_memory_mb / shard_memory_target) + 1)

        # Round up to power of 2 for cleaner division
        num_shards = 2 ** (num_shards - 1).bit_length()
        num_shards = min(num_shards, num_layers)  # Can't have more shards than layers

        # Create shard configs
        layers_per_shard = num_layers // num_shards
        shard_size_mb = model_size_mb / num_shards

        configs = []
        for i in range(num_shards):
            start = i * layers_per_shard
            end = (i + 1) * layers_per_shard if i < num_shards - 1 else num_layers

            configs.append(ShardConfig(
                shard_id=i,
                total_shards=num_shards,
                start_layer=start,
                end_layer=end,
                has_embedding=(i == 0),
                has_lm_head=(i == num_shards - 1),
                estimated_size_mb=shard_size_mb
            ))

        logger.info(f"Sharding strategy: {num_shards} shards, "
                    f"{layers_per_shard} layers each, "
                    f"~{shard_size_mb:.1f}MB per shard")

        return ShardingStrategy.LAYER_GROUPS, configs

    def extract_shard_weights(self, shard_config: ShardConfig) -> Dict[str, torch.Tensor]:
        """Extract weights for a specific shard from the full model."""
        state_dict = self.model.state_dict()
        shard_weights = {}

        for name, param in state_dict.items():
            # Check if this parameter belongs to this shard
            if self._param_belongs_to_shard(name, shard_config):
                shard_weights[name] = param.clone()

        return shard_weights

    def _param_belongs_to_shard(self, param_name: str, config: ShardConfig) -> bool:
        """Check if a parameter belongs to a shard."""
        # Embedding layer
        if "embed" in param_name.lower() or "wte" in param_name.lower():
            return config.has_embedding

        # LM head
        if "lm_head" in param_name.lower() or "output" in param_name.lower():
            return config.has_lm_head

        # Transformer layers - extract layer number
        import re
        match = re.search(r'layers?[._](\d+)', param_name.lower())
        if match:
            layer_num = int(match.group(1))
            return config.start_layer <= layer_num < config.end_layer

        # Final norm (belongs to last shard)
        if "final" in param_name.lower() or "ln_f" in param_name.lower():
            return config.has_lm_head

        # Default: include in first shard
        return config.shard_id == 0

    def save_shard(self, shard_id: int, version: int) -> Optional[Path]:
        """Save a shard to disk."""
        if shard_id not in self.my_shards:
            logger.error(f"Shard {shard_id} not loaded")
            return None

        shard_state = self.my_shards[shard_id]

        if not shard_state.weights:
            # Extract from model
            shard_state.weights = self.extract_shard_weights(shard_state.config)

        # Compute hash
        model_hash = self._compute_weights_hash(shard_state.weights)

        # Save
        filename = f"shard_{shard_id}_v{version}.pt"
        path = self.checkpoint_dir / filename

        torch.save({
            "shard_id": shard_id,
            "config": {
                "shard_id": shard_state.config.shard_id,
                "total_shards": shard_state.config.total_shards,
                "start_layer": shard_state.config.start_layer,
                "end_layer": shard_state.config.end_layer,
                "has_embedding": shard_state.config.has_embedding,
                "has_lm_head": shard_state.config.has_lm_head,
            },
            "weights": shard_state.weights,
            "version": version,
            "model_hash": model_hash,
            "timestamp": time.time(),
        }, path)

        # Update state
        shard_state.version = version
        shard_state.model_hash = model_hash
        shard_state.last_updated = time.time()

        logger.info(f"Saved shard {shard_id} v{version} to {path}")

        return path

    def load_shard(self, shard_id: int, path: Optional[Path] = None) -> bool:
        """Load a shard from disk."""
        if path is None:
            # Find latest version
            pattern = f"shard_{shard_id}_v*.pt"
            files = list(self.checkpoint_dir.glob(pattern))
            if not files:
                logger.warning(f"No shard {shard_id} found on disk")
                return False
            path = max(files, key=lambda p: p.stat().st_mtime)

        try:
            data = torch.load(path, map_location="cpu", weights_only=False)

            config = ShardConfig(**data["config"])

            self.my_shards[shard_id] = ShardState(
                config=config,
                version=data["version"],
                model_hash=data["model_hash"],
                weights=data["weights"],
                is_loaded=True,
                last_updated=data.get("timestamp", time.time()),
            )

            logger.info(f"Loaded shard {shard_id} v{data['version']} from {path}")
            return True

        except Exception as e:
            logger.error(f"Failed to load shard {shard_id}: {e}")
            return False

    def load_shard_into_model(self, shard_id: int) -> bool:
        """Load shard weights into the model."""
        if shard_id not in self.my_shards:
            logger.error(f"Shard {shard_id} not loaded")
            return False

        shard_state = self.my_shards[shard_id]

        if not shard_state.weights:
            logger.error(f"Shard {shard_id} has no weights")
            return False

        # Load into model
        current_state = self.model.state_dict()

        for name, param in shard_state.weights.items():
            if name in current_state:
                current_state[name].copy_(param)
            else:
                logger.warning(f"Parameter {name} not found in model")

        logger.info(f"Loaded shard {shard_id} weights into model")
        return True

    def _compute_weights_hash(self, weights: Dict[str, torch.Tensor]) -> str:
        """Compute hash of weights."""
        hasher = hashlib.sha256()

        for name in sorted(weights.keys()):
            hasher.update(name.encode())
            hasher.update(weights[name].cpu().numpy().tobytes()[:1000])

        return hasher.hexdigest()[:16]

    def get_shard_status(self) -> Dict[str, Any]:
        """Get status of all shards on this node."""
        return {
            "node_id": self.node_id,
            "shards": {
                shard_id: {
                    "shard_id": shard_id,
                    "version": state.version,
                    "model_hash": state.model_hash,
                    "is_loaded": state.is_loaded,
                    "layers": f"{state.config.start_layer}-{state.config.end_layer}",
                    "has_embedding": state.config.has_embedding,
                    "has_lm_head": state.config.has_lm_head,
                    "replicas": len(state.replicas),
                }
                for shard_id, state in self.my_shards.items()
            },
            "total_shards": len(self.shard_configs),
        }


def compute_shard_assignment(
    node_capacities: Dict[str, float],  # node_id -> available_memory_mb
    model_size_mb: float,
    num_layers: int,
    min_replicas: int = 3
) -> Dict[str, List[int]]:
    """
    Compute optimal shard assignment across nodes.

    This is a simplified greedy algorithm. A production system would use
    more sophisticated optimization (e.g., constraint satisfaction, ILP).

    Args:
        node_capacities: Available memory per node
        model_size_mb: Total model size
        num_layers: Number of transformer layers
        min_replicas: Minimum replicas per shard

    Returns:
        Dict mapping node_id -> list of shard_ids to hold
    """
    # Sort nodes by capacity (largest first)
    sorted_nodes = sorted(node_capacities.items(), key=lambda x: -x[1])

    if not sorted_nodes:
        return {}

    # Determine number of shards based on largest node
    max_capacity = sorted_nodes[0][1]
    shard_size = max_capacity * 0.8  # Leave headroom
    num_shards = max(1, int(model_size_mb / shard_size) + 1)
    num_shards = min(num_shards, num_layers)

    actual_shard_size = model_size_mb / num_shards

    # Assign shards to nodes (round-robin with capacity check)
    assignments: Dict[str, List[int]] = {node_id: [] for node_id, _ in sorted_nodes}
    shard_counts = {shard_id: 0 for shard_id in range(num_shards)}

    # First pass: assign one shard per node (ensure coverage)
    for shard_id in range(num_shards):
        for node_id, capacity in sorted_nodes:
            current_load = len(assignments[node_id]) * actual_shard_size
            if current_load + actual_shard_size <= capacity:
                assignments[node_id].append(shard_id)
                shard_counts[shard_id] += 1
                break

    # Second pass: add replicas until min_replicas reached
    for shard_id in range(num_shards):
        while shard_counts[shard_id] < min_replicas:
            # Find node with capacity that doesn't already have this shard
            assigned = False
            for node_id, capacity in sorted_nodes:
                if shard_id in assignments[node_id]:
                    continue
                current_load = len(assignments[node_id]) * actual_shard_size
                if current_load + actual_shard_size <= capacity:
                    assignments[node_id].append(shard_id)
                    shard_counts[shard_id] += 1
                    assigned = True
                    break

            if not assigned:
                logger.warning(f"Could not assign {min_replicas} replicas for shard {shard_id}")
                break

    return assignments